Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,10 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {

// Remember a subset of the capabilities, so we can use them
// later in the protocol.
c.Capabilities = capabilities & (CapabilityClientDeprecateEOF)
c.Capabilities = 0
if !params.DisableClientDeprecateEOF {
c.Capabilities = capabilities & (CapabilityClientDeprecateEOF)
}

// Handle switch to SSL if necessary.
if params.Flags&CapabilityClientSSL > 0 {
Expand Down
199 changes: 188 additions & 11 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@ package mysql

import (
"bufio"
"errors"
"fmt"
"io"
"net"
"strings"
"sync"
"time"

"vitess.io/vitess/go/bucketpool"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
)

Expand Down Expand Up @@ -66,6 +70,9 @@ type Conn struct {
// If there are any ongoing reads or writes, they may get interrupted.
conn net.Conn

// For server-side connections, listener points to the server object.
listener *Listener

// ConnectionID is set:
// - at Connect() time for clients, with the value returned by
// the server.
Expand Down Expand Up @@ -164,15 +171,18 @@ func newConn(conn net.Conn) *Conn {
}

// newServerConn should be used to create server connections.
// The only difference from "client" newConn is ability to control buffer size
// for reads.
func newServerConn(conn net.Conn, connReadBufferSize int) *Conn {
//
// It stashes a reference to the listener to be able to determine if
// the server is shutting down, and has the ability to control buffer
// size for reads.
func newServerConn(conn net.Conn, listener *Listener) *Conn {
c := &Conn{
conn: conn,
closed: sync2.NewAtomicBool(false),
conn: conn,
listener: listener,
closed: sync2.NewAtomicBool(false),
}
if connReadBufferSize > 0 {
c.bufferedReader = bufio.NewReaderSize(conn, connReadBufferSize)
if listener.connReadBufferSize > 0 {
c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize)
}
return c
}
Expand Down Expand Up @@ -673,6 +683,166 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
return c.writeEphemeralPacket()
}

// handleNextCommand is called in the server loop to process
// incoming packets.
func (c *Conn) handleNextCommand(handler Handler) error {
c.sequence = 0
data, err := c.readEphemeralPacket()
if err != nil {
// Don't log EOF errors. They cause too much spam.
// Note the EOF detection is not 100%
// guaranteed, in the case where the client
// connection is already closed before we call
// 'readEphemeralPacket'. This is a corner
// case though, and very unlikely to happen,
// and the only downside is we log a bit more then.
if err != io.EOF {
log.Errorf("Error reading packet from %s: %v", c, err)
}
return err
}

switch data[0] {
case ComQuit:
c.recycleReadPacket()
return errors.New("ComQuit")
case ComInitDB:
db := c.parseComInitDB(data)
c.recycleReadPacket()
c.SchemaName = db
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComInitDB result to %s: %v", c, err)
return err
}
case ComQuery:
// flush is called at the end of this block.
// We cannot encapsulate it with a defer inside a func because
// we have to return from this func if it fails.
c.startWriterBuffering()

queryStart := time.Now()
query := c.parseComQuery(data)
c.recycleReadPacket()
fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false

err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}

if !fieldSent {
fieldSent = true

if len(qr.Fields) == 0 {
sendFinished = true

// A successful callback with no fields means that this was a
// DML or other write-only operation.
//
// We should not send any more packets after this, but make sure
// to extract the affected rows and last insert id from the result
// struct here since clients expect it.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, handler.WarningCount(c))
}
if err := c.writeFields(qr); err != nil {
return err
}
}

return c.writeRows(qr)
})

// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return err
}

// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
}
}

timings.Record(queryTimingKey, queryStart)

if err := c.flush(); err != nil {
log.Errorf("Conn %v: Flush() failed: %v", c.ID(), err)
return err
}

case ComPing:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll probably have to think about this when resolving the merge conflict. When the change was made to ComPing, the listener was accessible, but now it's not.

Copy link
Copy Markdown
Member Author

@demmer demmer Oct 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just rebased the branch and handled this by stashing a reference to the listener in the Conn when used for server-side connections.

I think it's a reasonably safe invariant that Conn.handleNextCommand is only called for server-side connections for which newServerConn was used to create it.

c.recycleReadPacket()
// Return error if listener was shut down and OK otherwise
if c.listener.isShutdown() {
if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil {
log.Errorf("Error writing ComPing error to %s: %v", c, err)
return err
}
} else {
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComPing result to %s: %v", c, err)
return err
}
}
case ComSetOption:
if operation, ok := c.parseComSetOption(data); ok {
switch operation {
case 0:
c.Capabilities |= CapabilityClientMultiStatements
case 1:
c.Capabilities &^= CapabilityClientMultiStatements
default:
log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
}
if err := c.writeEndResult(false, 0, 0, 0); err != nil {
log.Errorf("Error writeEndResult error %v ", err)
return err
}
} else {
log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
}
default:
log.Errorf("Got unhandled packet from %s, returning error: %v", c, data)
c.recycleReadPacket()
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil {
log.Errorf("Error writing error packet to %s: %s", c, err)
return err
}
}

return nil
}

//
// Packet parsing methods, for generic packets.
//
Expand All @@ -697,14 +867,21 @@ func isEOFPacket(data []byte) bool {
return data[0] == EOFPacket && len(data) < 9
}

// parseEOFPacket returns true if there are more results to receive.
func parseEOFPacket(data []byte) (bool, error) {
// parseEOFPacket returns the warning count and a boolean to indicate if there
// are more results to receive.
//
// Note: This is only valid on actual EOF packets and not on OK packets with the EOF
// type code set, i.e. should not be used if ClientDeprecateEOF is set.
func parseEOFPacket(data []byte) (warnings uint16, more bool, err error) {
// The warning count is in position 2 & 3
warnings, _, ok := readUint16(data, 1)

// The status flag is in position 4 & 5
statusFlags, _, ok := readUint16(data, 3)
if !ok {
return false, fmt.Errorf("invalid EOF packet statusFlags: %v", data)
return 0, false, fmt.Errorf("invalid EOF packet statusFlags: %v", data)
}
return (statusFlags & ServerMoreResultsExists) != 0, nil
return warnings, (statusFlags & ServerMoreResultsExists) != 0, nil
}

func parseOKPacket(data []byte) (uint64, uint64, uint16, uint16, error) {
Expand Down
4 changes: 4 additions & 0 deletions go/mysql/conn_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type ConnParams struct {
// The following is only set when the deprecated "dbname" flags are
// supplied and will be removed.
DeprecatedDBName string

// The following is only set to force the client to connect without
// using CapabilityClientDeprecateEOF
DisableClientDeprecateEOF bool
}

// EnableSSL will set the right flag on the parameters.
Expand Down
70 changes: 67 additions & 3 deletions go/mysql/endtoend/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,23 +142,30 @@ func TestClientFoundRows(t *testing.T) {
}
}

func TestMultiResult(t *testing.T) {
func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) {
ctx := context.Background()
connParams.DisableClientDeprecateEOF = disableClientDeprecateEOF

conn, err := mysql.Connect(ctx, &connParams)
expectNoError(t, err)
defer conn.Close()

connParams.DisableClientDeprecateEOF = false

expectFlag(t, "Negotiated ClientDeprecateEOF flag", (conn.Capabilities&mysql.CapabilityClientDeprecateEOF) != 0, !disableClientDeprecateEOF)
defer conn.Close()

qr, more, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true)
expectNoError(t, err)
expectFlag(t, "ExecuteMultiFetch(multi result)", more, true)
expectRows(t, "ExecuteMultiFetch(multi result)", qr, 1)

qr, more, err = conn.ReadQueryResult(10, true)
qr, more, _, err = conn.ReadQueryResult(10, true)
expectNoError(t, err)
expectFlag(t, "ReadQueryResult(1)", more, true)
expectRows(t, "ReadQueryResult(1)", qr, 0)

qr, more, err = conn.ReadQueryResult(10, true)
qr, more, _, err = conn.ReadQueryResult(10, true)
expectNoError(t, err)
expectFlag(t, "ReadQueryResult(2)", more, false)
expectRows(t, "ReadQueryResult(2)", qr, 1)
Expand All @@ -172,6 +179,63 @@ func TestMultiResult(t *testing.T) {
expectNoError(t, err)
expectFlag(t, "ExecuteMultiFetch(no result)", more, false)
expectRows(t, "ExecuteMultiFetch(no result)", qr, 0)

// The ClientDeprecateEOF protocol change has a subtle twist in which an EOF or OK
// packet happens to have the status flags in the same position if the affected_rows
// and last_insert_id are both one byte long:
//
// https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html
// https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
//
// It turns out that there are no actual cases in which clients end up needing to make
// this distinction. If either affected_rows or last_insert_id are non-zero, the protocol
// sends an OK packet unilaterally which is properly parsed. If not, then regardless of the
// negotiated version, it can properly send the status flags.
//
result, err := conn.ExecuteFetch("create table a(id int, name varchar(128), primary key(id))", 0, false)
if err != nil {
t.Fatalf("create table failed: %v", err)
}
if result.RowsAffected != 0 {
t.Errorf("create table returned RowsAffected %v, was expecting 0", result.RowsAffected)
}

for i := 0; i < 255; i++ {
result, err := conn.ExecuteFetch(fmt.Sprintf("insert into a(id, name) values(%v, 'nice name %v')", 1000+i, i), 1000, true)
if err != nil {
t.Fatalf("ExecuteFetch(%v) failed: %v", i, err)
}
if result.RowsAffected != 1 {
t.Errorf("insert into returned RowsAffected %v, was expecting 1", result.RowsAffected)
}
}

qr, more, err = conn.ExecuteFetchMulti("update a set name = concat(name, ' updated'); select * from a; select count(*) from a", 300, true)
expectNoError(t, err)
expectFlag(t, "ExecuteMultiFetch(multi result)", more, true)
expectRows(t, "ExecuteMultiFetch(multi result)", qr, 255)

qr, more, _, err = conn.ReadQueryResult(300, true)
expectNoError(t, err)
expectFlag(t, "ReadQueryResult(1)", more, true)
expectRows(t, "ReadQueryResult(1)", qr, 255)

qr, more, _, err = conn.ReadQueryResult(300, true)
expectNoError(t, err)
expectFlag(t, "ReadQueryResult(2)", more, false)
expectRows(t, "ReadQueryResult(2)", qr, 1)

result, err = conn.ExecuteFetch("drop table a", 10, true)
if err != nil {
t.Fatalf("drop table failed: %v", err)
}
}

func TestMultiResultDeprecateEOF(t *testing.T) {
doTestMultiResult(t, false)
}
func TestMultiResultNoDeprecateEOF(t *testing.T) {
doTestMultiResult(t, true)
}

func expectNoError(t *testing.T, err error) {
Expand Down
Loading