Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 37 additions & 133 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ const (
// read or write a packet while one is already used.
ephemeralUnused = iota

// ephemeralWriteGlobalBuffer means conn.buffer was used to write
// a packet. The first four bytes contain size and sequence.
ephemeralWriteGlobalBuffer

// ephemeralWriteSingleBuffer means a single buffer was
// allocated to write a packet. It is in
// c.currentEphemeralWriteBuffer. The first four bytes contain size
Expand All @@ -55,10 +51,6 @@ const (
// The allocated buffer is in c.currentEphemeralWriteBuffer.
ephemeralWriteBigBuffer

// ephemeralReadGlobalBuffer means conn.buffer was used for reading
// an ephemeral packet.
ephemeralReadGlobalBuffer

// ephemeralReadSingleBuffer means we are using a pool of buffers
// for reading.
ephemeralReadSingleBuffer
Expand Down Expand Up @@ -155,27 +147,6 @@ type Conn struct {
// fields, this is set to an empty array (but not nil).
fields []*querypb.Field

// Internal buffer for zero-allocation reads and writes. This
// uses the fact that both sides of a connection either read
// packets, or write packets, but never do both, and both
// sides know who is expected to read or write a packet next.
//
// Reading side: if the next expected packet will most likely be
// small, and we don't need to hand on to the memory after reading
// the packet, use readEphemeralPacket instead of readPacket.
// If the packet is too big, it will revert to the usual read.
// But if the packet is smaller than connBufferSize, this buffer
// will be used instead.
//
// Writing side: if the next packet to write is smaller than
// connBufferSize-4, this buffer can be used to create a
// packet. It will contain both the size and sequence header,
// and the contents of the packet.
// Call startEphemeralPacket(length) to get a buffer. If length
// is smaller or equal than connBufferSize-4, this buffer will be used.
// Otherwise memory will be allocated for it.
buffer []byte

// Keep track of how and of the buffer we allocated for an
// ephemeral packet on the read and write sides.
// These fields are used by:
Expand All @@ -201,83 +172,27 @@ func newConn(conn net.Conn) *Conn {
reader: bufio.NewReaderSize(conn, connBufferSize),
writer: bufio.NewWriterSize(conn, connBufferSize),
sequence: 0,
buffer: make([]byte, connBufferSize),
}
}

// readPacketDirect attempts to read a packet from the socket directly.
// It needs to be used for the first handshake packet the server receives,
// so we do't buffer the SSL negotiation packet. As a shortcut, only
// packets smaller than MaxPacketSize can be read here.
func (c *Conn) readPacketDirect() ([]byte, error) {
var header [4]byte
if _, err := io.ReadFull(c.conn, header[:]); err != nil {
// Propagate as is so server can ignore this kind of error
// Same as readEphemeralPacket()
if err == io.EOF {
return nil, err
}
// Treat connection reset by peer as io.EOF, otherwise is too spammy.
if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
return nil, io.EOF
}
return nil, fmt.Errorf("io.ReadFull(header size) failed: %v", err)
}

sequence := uint8(header[3])
if sequence != c.sequence {
return nil, fmt.Errorf("invalid sequence, expected %v got %v", c.sequence, sequence)
}

c.sequence++

length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
if length <= cap(c.buffer) {
// Fast path: read into buffer, we're good.
c.buffer = c.buffer[:length]
if _, err := io.ReadFull(c.conn, c.buffer); err != nil {
return nil, fmt.Errorf("io.ReadFull(direct packet body of length %v) failed: %v", length, err)
}
return c.buffer, nil
}

// Sanity check
if length == MaxPacketSize {
return nil, fmt.Errorf("readPacketDirect doesn't support more than one packet")
}

// Slow path, revert to allocating.
data := make([]byte, length)
if _, err := io.ReadFull(c.conn, data); err != nil {
return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
}
return data, nil
}

// readEphemeralPacket attempts to read a packet into c.buffer. Do
// not use this method if the contents of the packet needs to be kept
// after the next readEphemeralPacket. If the packet is bigger than
// connBufferSize, we revert to using the same behavior as a regular
// readPacket. recycleReadPacket() has to be called after this method
// is used, and before we read or write any other packet on the connection.
//
// Note if the connection is closed already, an error will be
// returned, and it may not be io.EOF. If the connection closes while
// we are stuck waiting for data, an error will also be returned, and
// it most likely will be io.EOF.
func (c *Conn) readEphemeralPacket() ([]byte, error) {
func (c *Conn) readEphemeralPacketHelper(direct bool) ([]byte, error) {
if c.currentEphemeralPolicy != ephemeralUnused {
panic(fmt.Errorf("readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy))
}

var r io.Reader = c.reader
if direct {
r = c.conn
}

// Note io.ReadFull will return two different types of errors:
// 1. if the socket is already closed, and the go runtime knows it,
// then ReadFull will return an error (different than EOF),
// someting like 'read: connection reset by peer'.
// 2. if the socket is not closed while we start the read,
// but gets closed after the read is started, we'll get io.EOF.
var header [4]byte
if _, err := io.ReadFull(c.reader, header[:]); err != nil {
if _, err := io.ReadFull(r, header[:]); err != nil {
// The special casing of propagating io.EOF up
// is used by the server side only, to suppress an error
// message if a client just disconnects.
Expand All @@ -303,26 +218,21 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) {
// exactly size MaxPacketSize.
return nil, nil
}
if length <= cap(c.buffer) {
// Fast path: read into buffer, we're good.
c.currentEphemeralPolicy = ephemeralReadGlobalBuffer
c.buffer = c.buffer[:length]
if _, err := io.ReadFull(c.reader, c.buffer); err != nil {
return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
}
return c.buffer, nil
}

// Slightly slower path: single packet. Use the bufPool.
// Use the bufPool.
if length < MaxPacketSize {
c.currentEphemeralPolicy = ephemeralReadSingleBuffer
c.currentEphemeralReadBuffer = bufPool.Get(length)
if _, err := io.ReadFull(c.reader, *c.currentEphemeralReadBuffer); err != nil {
if _, err := io.ReadFull(r, *c.currentEphemeralReadBuffer); err != nil {
return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err)
}
return *c.currentEphemeralReadBuffer, nil
}

if direct {
return nil, fmt.Errorf("readEphemeralPacketDirect doesn't support more than one packet")
}

// Much slower path, revert to allocating everything from scratch.
// We're going to concatenate a lot of data anyway, can't really
// optimize this code path easily.
Expand Down Expand Up @@ -351,12 +261,30 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) {
return data, nil
}

// readEphemeralPacketDirect attempts to read a packet from the socket directly.
// It needs to be used for the first handshake packet the server receives,
// so we do't buffer the SSL negotiation packet. As a shortcut, only
// packets smaller than MaxPacketSize can be read here.
func (c *Conn) readEphemeralPacketDirect() ([]byte, error) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i wonder if it makes sense to put "finisher" functions as return values to these? That way it prevents people from forgetting to recycle. same with the write side.

return c.readEphemeralPacketHelper(true)
}

// readEphemeralPacket attempts to read a packet into buffer from sync.Pool. Do
// not use this method if the contents of the packet needs to be kept
// after the next readEphemeralPacket.
//
// Note if the connection is closed already, an error will be
// returned, and it may not be io.EOF. If the connection closes while
// we are stuck waiting for data, an error will also be returned, and
// it most likely will be io.EOF.
func (c *Conn) readEphemeralPacket() ([]byte, error) {
return c.readEphemeralPacketHelper(false)
}

// recycleReadPacket recycles the read packet. It needs to be called
// after readEphemeralPacket was called.
func (c *Conn) recycleReadPacket() {
switch c.currentEphemeralPolicy {
case ephemeralReadGlobalBuffer:
// We used small built-in buffer, nothing to do.
case ephemeralReadSingleBuffer:
// We are using the pool, put the buffer back in.
bufPool.Put(c.currentEphemeralReadBuffer)
Expand All @@ -365,7 +293,7 @@ func (c *Conn) recycleReadPacket() {
// We allocated a one-time buffer we can't re-use.
// Nothing to do. Nil out for safety.
c.currentEphemeralReadBuffer = nil
case ephemeralUnused, ephemeralWriteGlobalBuffer, ephemeralWriteSingleBuffer, ephemeralWriteBigBuffer:
case ephemeralUnused, ephemeralWriteSingleBuffer, ephemeralWriteBigBuffer:
// Programming error.
panic(fmt.Errorf("trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
}
Expand Down Expand Up @@ -515,20 +443,7 @@ func (c *Conn) startEphemeralPacket(length int) []byte {
panic("startEphemeralPacket cannot be used while a packet is already started.")
}

// Fast path: we can reuse a single memory buffer for
// both the header and the data.
if length <= cap(c.buffer)-4 {
c.currentEphemeralPolicy = ephemeralWriteGlobalBuffer
c.buffer = c.buffer[:length+4]
c.buffer[0] = byte(length)
c.buffer[1] = byte(length >> 8)
c.buffer[2] = byte(length >> 16)
c.buffer[3] = c.sequence
c.sequence++
return c.buffer[4:]
}

// Slower path: we can use a single buffer for both the header and the data, but it has to be allocated.
// get buffer from pool
if length < MaxPacketSize {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldn't need this length check anymore. same with read.

and also BigBuffer can go away (or whatever the last "policy" is)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True for writes.
but for the reads code is different for large buffers - it reads packets one by one instead of ReadAll.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah okay. feel free to leave as is

c.currentEphemeralPolicy = ephemeralWriteSingleBuffer

Expand Down Expand Up @@ -560,14 +475,6 @@ func (c *Conn) writeEphemeralPacket(direct bool) error {
}

switch c.currentEphemeralPolicy {
case ephemeralWriteGlobalBuffer:
// Just write c.buffer as a single buffer.
// It has both header and data.
if n, err := w.Write(c.buffer); err != nil {
return fmt.Errorf("Conn %v: Write(c.buffer) failed: %v", c.ID(), err)
} else if n != len(c.buffer) {
return fmt.Errorf("Conn %v: Write(c.buffer) returned a short write: %v < %v", c.ID(), n, len(c.buffer))
}
case ephemeralWriteSingleBuffer:
// Write the allocated buffer as a single buffer.
// It has both header and data.
Expand All @@ -586,7 +493,7 @@ func (c *Conn) writeEphemeralPacket(direct bool) error {
if direct {
return c.flush()
}
case ephemeralUnused, ephemeralReadGlobalBuffer, ephemeralReadSingleBuffer, ephemeralReadBigBuffer:
case ephemeralUnused, ephemeralReadSingleBuffer, ephemeralReadBigBuffer:
// Programming error.
panic(fmt.Errorf("Conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy))
}
Expand All @@ -598,8 +505,6 @@ func (c *Conn) writeEphemeralPacket(direct bool) error {
// after writeEphemeralPacket was called.
func (c *Conn) recycleWritePacket() {
switch c.currentEphemeralPolicy {
case ephemeralWriteGlobalBuffer:
// We used small built-in buffer, nothing to do.
case ephemeralWriteSingleBuffer:
// Release our reference so the buffer can be gced
bufPool.Put(c.currentEphemeralWriteBuffer)
Expand All @@ -609,8 +514,7 @@ func (c *Conn) recycleWritePacket() {
// N.B. Unlike the read packet, we actually assign the big buffer to currentEphemeralReadBuffer,
// so we should remove our reference to it.
c.currentEphemeralWriteBuffer = nil
case ephemeralUnused, ephemeralReadGlobalBuffer,
ephemeralReadSingleBuffer, ephemeralReadBigBuffer:
case ephemeralUnused, ephemeralReadSingleBuffer, ephemeralReadBigBuffer:
// Programming error.
panic(fmt.Errorf("trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy))
}
Expand Down
11 changes: 7 additions & 4 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,14 @@ func verifyPacketComms(t *testing.T, cConn, sConn *Conn, data []byte) {
verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readEphemeralPacket)
sConn.recycleReadPacket()

// All three writes, with readPacketDirect, if size allows it.
// All three writes, with readEphemeralPacketDirect, if size allows it.
if len(data) < MaxPacketSize {
verifyPacketCommsSpecific(t, cConn, data, useWritePacket, sConn.readPacketDirect)
verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacket, sConn.readPacketDirect)
verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readPacketDirect)
verifyPacketCommsSpecific(t, cConn, data, useWritePacket, sConn.readEphemeralPacketDirect)
sConn.recycleReadPacket()
verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacket, sConn.readEphemeralPacketDirect)
sConn.recycleReadPacket()
verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readEphemeralPacketDirect)
sConn.recycleReadPacket()
}
}

Expand Down
4 changes: 3 additions & 1 deletion go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti

// Wait for the client response. This has to be a direct read,
// so we don't buffer the TLS negotiation packets.
response, err := c.readPacketDirect()
response, err := c.readEphemeralPacketDirect()
if err != nil {
// Don't log EOF errors. They cause too much spam, same as main read loop.
if err != io.EOF {
Expand All @@ -218,6 +218,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
return
}

c.recycleReadPacket()

if c.Capabilities&CapabilityClientSSL > 0 {
// SSL was enabled. We need to re-read the auth packet.
response, err = c.readEphemeralPacket()
Expand Down