diff --git a/go/mysql/conn.go b/go/mysql/conn.go index f5f4d48e637..3cea9e60293 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -40,10 +40,6 @@ const ( // read or write a packet while one is already used. ephemeralUnused = iota - // ephemeralWriteGlobalBuffer means conn.buffer was used to write - // a packet. The first four bytes contain size and sequence. - ephemeralWriteGlobalBuffer - // ephemeralWriteSingleBuffer means a single buffer was // allocated to write a packet. It is in // c.currentEphemeralWriteBuffer. The first four bytes contain size @@ -55,10 +51,6 @@ const ( // The allocated buffer is in c.currentEphemeralWriteBuffer. ephemeralWriteBigBuffer - // ephemeralReadGlobalBuffer means conn.buffer was used for reading - // an ephemeral packet. - ephemeralReadGlobalBuffer - // ephemeralReadSingleBuffer means we are using a pool of buffers // for reading. ephemeralReadSingleBuffer @@ -155,27 +147,6 @@ type Conn struct { // fields, this is set to an empty array (but not nil). fields []*querypb.Field - // Internal buffer for zero-allocation reads and writes. This - // uses the fact that both sides of a connection either read - // packets, or write packets, but never do both, and both - // sides know who is expected to read or write a packet next. - // - // Reading side: if the next expected packet will most likely be - // small, and we don't need to hand on to the memory after reading - // the packet, use readEphemeralPacket instead of readPacket. - // If the packet is too big, it will revert to the usual read. - // But if the packet is smaller than connBufferSize, this buffer - // will be used instead. - // - // Writing side: if the next packet to write is smaller than - // connBufferSize-4, this buffer can be used to create a - // packet. It will contain both the size and sequence header, - // and the contents of the packet. - // Call startEphemeralPacket(length) to get a buffer. If length - // is smaller or equal than connBufferSize-4, this buffer will be used. - // Otherwise memory will be allocated for it. - buffer []byte - // Keep track of how and of the buffer we allocated for an // ephemeral packet on the read and write sides. // These fields are used by: @@ -201,75 +172,19 @@ func newConn(conn net.Conn) *Conn { reader: bufio.NewReaderSize(conn, connBufferSize), writer: bufio.NewWriterSize(conn, connBufferSize), sequence: 0, - buffer: make([]byte, connBufferSize), - } -} - -// readPacketDirect attempts to read a packet from the socket directly. -// It needs to be used for the first handshake packet the server receives, -// so we do't buffer the SSL negotiation packet. As a shortcut, only -// packets smaller than MaxPacketSize can be read here. -func (c *Conn) readPacketDirect() ([]byte, error) { - var header [4]byte - if _, err := io.ReadFull(c.conn, header[:]); err != nil { - // Propagate as is so server can ignore this kind of error - // Same as readEphemeralPacket() - if err == io.EOF { - return nil, err - } - // Treat connection reset by peer as io.EOF, otherwise is too spammy. - if strings.HasSuffix(err.Error(), "read: connection reset by peer") { - return nil, io.EOF - } - return nil, fmt.Errorf("io.ReadFull(header size) failed: %v", err) } - - sequence := uint8(header[3]) - if sequence != c.sequence { - return nil, fmt.Errorf("invalid sequence, expected %v got %v", c.sequence, sequence) - } - - c.sequence++ - - length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16) - if length <= cap(c.buffer) { - // Fast path: read into buffer, we're good. - c.buffer = c.buffer[:length] - if _, err := io.ReadFull(c.conn, c.buffer); err != nil { - return nil, fmt.Errorf("io.ReadFull(direct packet body of length %v) failed: %v", length, err) - } - return c.buffer, nil - } - - // Sanity check - if length == MaxPacketSize { - return nil, fmt.Errorf("readPacketDirect doesn't support more than one packet") - } - - // Slow path, revert to allocating. - data := make([]byte, length) - if _, err := io.ReadFull(c.conn, data); err != nil { - return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err) - } - return data, nil } -// readEphemeralPacket attempts to read a packet into c.buffer. Do -// not use this method if the contents of the packet needs to be kept -// after the next readEphemeralPacket. If the packet is bigger than -// connBufferSize, we revert to using the same behavior as a regular -// readPacket. recycleReadPacket() has to be called after this method -// is used, and before we read or write any other packet on the connection. -// -// Note if the connection is closed already, an error will be -// returned, and it may not be io.EOF. If the connection closes while -// we are stuck waiting for data, an error will also be returned, and -// it most likely will be io.EOF. -func (c *Conn) readEphemeralPacket() ([]byte, error) { +func (c *Conn) readEphemeralPacketHelper(direct bool) ([]byte, error) { if c.currentEphemeralPolicy != ephemeralUnused { panic(fmt.Errorf("readEphemeralPacket: unexpected currentEphemeralPolicy: %v", c.currentEphemeralPolicy)) } + var r io.Reader = c.reader + if direct { + r = c.conn + } + // Note io.ReadFull will return two different types of errors: // 1. if the socket is already closed, and the go runtime knows it, // then ReadFull will return an error (different than EOF), @@ -277,7 +192,7 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) { // 2. if the socket is not closed while we start the read, // but gets closed after the read is started, we'll get io.EOF. var header [4]byte - if _, err := io.ReadFull(c.reader, header[:]); err != nil { + if _, err := io.ReadFull(r, header[:]); err != nil { // The special casing of propagating io.EOF up // is used by the server side only, to suppress an error // message if a client just disconnects. @@ -303,26 +218,21 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) { // exactly size MaxPacketSize. return nil, nil } - if length <= cap(c.buffer) { - // Fast path: read into buffer, we're good. - c.currentEphemeralPolicy = ephemeralReadGlobalBuffer - c.buffer = c.buffer[:length] - if _, err := io.ReadFull(c.reader, c.buffer); err != nil { - return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err) - } - return c.buffer, nil - } - // Slightly slower path: single packet. Use the bufPool. + // Use the bufPool. if length < MaxPacketSize { c.currentEphemeralPolicy = ephemeralReadSingleBuffer c.currentEphemeralReadBuffer = bufPool.Get(length) - if _, err := io.ReadFull(c.reader, *c.currentEphemeralReadBuffer); err != nil { + if _, err := io.ReadFull(r, *c.currentEphemeralReadBuffer); err != nil { return nil, fmt.Errorf("io.ReadFull(packet body of length %v) failed: %v", length, err) } return *c.currentEphemeralReadBuffer, nil } + if direct { + return nil, fmt.Errorf("readEphemeralPacketDirect doesn't support more than one packet") + } + // Much slower path, revert to allocating everything from scratch. // We're going to concatenate a lot of data anyway, can't really // optimize this code path easily. @@ -351,12 +261,30 @@ func (c *Conn) readEphemeralPacket() ([]byte, error) { return data, nil } +// readEphemeralPacketDirect attempts to read a packet from the socket directly. +// It needs to be used for the first handshake packet the server receives, +// so we do't buffer the SSL negotiation packet. As a shortcut, only +// packets smaller than MaxPacketSize can be read here. +func (c *Conn) readEphemeralPacketDirect() ([]byte, error) { + return c.readEphemeralPacketHelper(true) +} + +// readEphemeralPacket attempts to read a packet into buffer from sync.Pool. Do +// not use this method if the contents of the packet needs to be kept +// after the next readEphemeralPacket. +// +// Note if the connection is closed already, an error will be +// returned, and it may not be io.EOF. If the connection closes while +// we are stuck waiting for data, an error will also be returned, and +// it most likely will be io.EOF. +func (c *Conn) readEphemeralPacket() ([]byte, error) { + return c.readEphemeralPacketHelper(false) +} + // recycleReadPacket recycles the read packet. It needs to be called // after readEphemeralPacket was called. func (c *Conn) recycleReadPacket() { switch c.currentEphemeralPolicy { - case ephemeralReadGlobalBuffer: - // We used small built-in buffer, nothing to do. case ephemeralReadSingleBuffer: // We are using the pool, put the buffer back in. bufPool.Put(c.currentEphemeralReadBuffer) @@ -365,7 +293,7 @@ func (c *Conn) recycleReadPacket() { // We allocated a one-time buffer we can't re-use. // Nothing to do. Nil out for safety. c.currentEphemeralReadBuffer = nil - case ephemeralUnused, ephemeralWriteGlobalBuffer, ephemeralWriteSingleBuffer, ephemeralWriteBigBuffer: + case ephemeralUnused, ephemeralWriteSingleBuffer, ephemeralWriteBigBuffer: // Programming error. panic(fmt.Errorf("trying to call recycleReadPacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy)) } @@ -515,20 +443,7 @@ func (c *Conn) startEphemeralPacket(length int) []byte { panic("startEphemeralPacket cannot be used while a packet is already started.") } - // Fast path: we can reuse a single memory buffer for - // both the header and the data. - if length <= cap(c.buffer)-4 { - c.currentEphemeralPolicy = ephemeralWriteGlobalBuffer - c.buffer = c.buffer[:length+4] - c.buffer[0] = byte(length) - c.buffer[1] = byte(length >> 8) - c.buffer[2] = byte(length >> 16) - c.buffer[3] = c.sequence - c.sequence++ - return c.buffer[4:] - } - - // Slower path: we can use a single buffer for both the header and the data, but it has to be allocated. + // get buffer from pool if length < MaxPacketSize { c.currentEphemeralPolicy = ephemeralWriteSingleBuffer @@ -560,14 +475,6 @@ func (c *Conn) writeEphemeralPacket(direct bool) error { } switch c.currentEphemeralPolicy { - case ephemeralWriteGlobalBuffer: - // Just write c.buffer as a single buffer. - // It has both header and data. - if n, err := w.Write(c.buffer); err != nil { - return fmt.Errorf("Conn %v: Write(c.buffer) failed: %v", c.ID(), err) - } else if n != len(c.buffer) { - return fmt.Errorf("Conn %v: Write(c.buffer) returned a short write: %v < %v", c.ID(), n, len(c.buffer)) - } case ephemeralWriteSingleBuffer: // Write the allocated buffer as a single buffer. // It has both header and data. @@ -586,7 +493,7 @@ func (c *Conn) writeEphemeralPacket(direct bool) error { if direct { return c.flush() } - case ephemeralUnused, ephemeralReadGlobalBuffer, ephemeralReadSingleBuffer, ephemeralReadBigBuffer: + case ephemeralUnused, ephemeralReadSingleBuffer, ephemeralReadBigBuffer: // Programming error. panic(fmt.Errorf("Conn %v: trying to call writeEphemeralPacket while currentEphemeralPolicy is %v", c.ID(), c.currentEphemeralPolicy)) } @@ -598,8 +505,6 @@ func (c *Conn) writeEphemeralPacket(direct bool) error { // after writeEphemeralPacket was called. func (c *Conn) recycleWritePacket() { switch c.currentEphemeralPolicy { - case ephemeralWriteGlobalBuffer: - // We used small built-in buffer, nothing to do. case ephemeralWriteSingleBuffer: // Release our reference so the buffer can be gced bufPool.Put(c.currentEphemeralWriteBuffer) @@ -609,8 +514,7 @@ func (c *Conn) recycleWritePacket() { // N.B. Unlike the read packet, we actually assign the big buffer to currentEphemeralReadBuffer, // so we should remove our reference to it. c.currentEphemeralWriteBuffer = nil - case ephemeralUnused, ephemeralReadGlobalBuffer, - ephemeralReadSingleBuffer, ephemeralReadBigBuffer: + case ephemeralUnused, ephemeralReadSingleBuffer, ephemeralReadBigBuffer: // Programming error. panic(fmt.Errorf("trying to call recycleWritePacket while currentEphemeralPolicy is %d", c.currentEphemeralPolicy)) } diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 50b99714d31..5989848ebd1 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -150,11 +150,14 @@ func verifyPacketComms(t *testing.T, cConn, sConn *Conn, data []byte) { verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readEphemeralPacket) sConn.recycleReadPacket() - // All three writes, with readPacketDirect, if size allows it. + // All three writes, with readEphemeralPacketDirect, if size allows it. if len(data) < MaxPacketSize { - verifyPacketCommsSpecific(t, cConn, data, useWritePacket, sConn.readPacketDirect) - verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacket, sConn.readPacketDirect) - verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readPacketDirect) + verifyPacketCommsSpecific(t, cConn, data, useWritePacket, sConn.readEphemeralPacketDirect) + sConn.recycleReadPacket() + verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacket, sConn.readEphemeralPacketDirect) + sConn.recycleReadPacket() + verifyPacketCommsSpecific(t, cConn, data, useWriteEphemeralPacketDirect, sConn.readEphemeralPacketDirect) + sConn.recycleReadPacket() } } diff --git a/go/mysql/server.go b/go/mysql/server.go index 6b634b9c8cc..15410decf31 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -204,7 +204,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti // Wait for the client response. This has to be a direct read, // so we don't buffer the TLS negotiation packets. - response, err := c.readPacketDirect() + response, err := c.readEphemeralPacketDirect() if err != nil { // Don't log EOF errors. They cause too much spam, same as main read loop. if err != io.EOF { @@ -218,6 +218,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti return } + c.recycleReadPacket() + if c.Capabilities&CapabilityClientSSL > 0 { // SSL was enabled. We need to re-read the auth packet. response, err = c.readEphemeralPacket()