diff --git a/go/mysql/client.go b/go/mysql/client.go index 3ef28912914..1cc117a8f6e 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -232,7 +232,10 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { // Remember a subset of the capabilities, so we can use them // later in the protocol. - c.Capabilities = capabilities & (CapabilityClientDeprecateEOF) + c.Capabilities = 0 + if !params.DisableClientDeprecateEOF { + c.Capabilities = capabilities & (CapabilityClientDeprecateEOF) + } // Handle switch to SSL if necessary. if params.Flags&CapabilityClientSSL > 0 { diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 68fe7cad211..15751f61116 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -18,14 +18,18 @@ package mysql import ( "bufio" + "errors" "fmt" "io" "net" "strings" "sync" + "time" "vitess.io/vitess/go/bucketpool" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/sync2" + "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -66,6 +70,9 @@ type Conn struct { // If there are any ongoing reads or writes, they may get interrupted. conn net.Conn + // For server-side connections, listener points to the server object. + listener *Listener + // ConnectionID is set: // - at Connect() time for clients, with the value returned by // the server. @@ -164,15 +171,18 @@ func newConn(conn net.Conn) *Conn { } // newServerConn should be used to create server connections. -// The only difference from "client" newConn is ability to control buffer size -// for reads. -func newServerConn(conn net.Conn, connReadBufferSize int) *Conn { +// +// It stashes a reference to the listener to be able to determine if +// the server is shutting down, and has the ability to control buffer +// size for reads. +func newServerConn(conn net.Conn, listener *Listener) *Conn { c := &Conn{ - conn: conn, - closed: sync2.NewAtomicBool(false), + conn: conn, + listener: listener, + closed: sync2.NewAtomicBool(false), } - if connReadBufferSize > 0 { - c.bufferedReader = bufio.NewReaderSize(conn, connReadBufferSize) + if listener.connReadBufferSize > 0 { + c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize) } return c } @@ -673,6 +683,166 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { return c.writeEphemeralPacket() } +// handleNextCommand is called in the server loop to process +// incoming packets. +func (c *Conn) handleNextCommand(handler Handler) error { + c.sequence = 0 + data, err := c.readEphemeralPacket() + if err != nil { + // Don't log EOF errors. They cause too much spam. + // Note the EOF detection is not 100% + // guaranteed, in the case where the client + // connection is already closed before we call + // 'readEphemeralPacket'. This is a corner + // case though, and very unlikely to happen, + // and the only downside is we log a bit more then. + if err != io.EOF { + log.Errorf("Error reading packet from %s: %v", c, err) + } + return err + } + + switch data[0] { + case ComQuit: + c.recycleReadPacket() + return errors.New("ComQuit") + case ComInitDB: + db := c.parseComInitDB(data) + c.recycleReadPacket() + c.SchemaName = db + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Errorf("Error writing ComInitDB result to %s: %v", c, err) + return err + } + case ComQuery: + // flush is called at the end of this block. + // We cannot encapsulate it with a defer inside a func because + // we have to return from this func if it fails. + c.startWriterBuffering() + + queryStart := time.Now() + query := c.parseComQuery(data) + c.recycleReadPacket() + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + + err := handler.ComQuery(c, query, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + + // A successful callback with no fields means that this was a + // DML or other write-only operation. + // + // We should not send any more packets after this, but make sure + // to extract the affected rows and last insert id from the result + // struct here since clients expect it. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, handler.WarningCount(c)) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeRows(qr) + }) + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return werr + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return err + } + + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return err + } + } + } + + timings.Record(queryTimingKey, queryStart) + + if err := c.flush(); err != nil { + log.Errorf("Conn %v: Flush() failed: %v", c.ID(), err) + return err + } + + case ComPing: + c.recycleReadPacket() + // Return error if listener was shut down and OK otherwise + if c.listener.isShutdown() { + if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { + log.Errorf("Error writing ComPing error to %s: %v", c, err) + return err + } + } else { + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Errorf("Error writing ComPing result to %s: %v", c, err) + return err + } + } + case ComSetOption: + if operation, ok := c.parseComSetOption(data); ok { + switch operation { + case 0: + c.Capabilities |= CapabilityClientMultiStatements + case 1: + c.Capabilities &^= CapabilityClientMultiStatements + default: + log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Errorf("Error writing error packet to client: %v", err) + return err + } + } + if err := c.writeEndResult(false, 0, 0, 0); err != nil { + log.Errorf("Error writeEndResult error %v ", err) + return err + } + } else { + log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Errorf("Error writing error packet to client: %v", err) + return err + } + } + default: + log.Errorf("Got unhandled packet from %s, returning error: %v", c, data) + c.recycleReadPacket() + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil { + log.Errorf("Error writing error packet to %s: %s", c, err) + return err + } + } + + return nil +} + // // Packet parsing methods, for generic packets. // @@ -697,14 +867,21 @@ func isEOFPacket(data []byte) bool { return data[0] == EOFPacket && len(data) < 9 } -// parseEOFPacket returns true if there are more results to receive. -func parseEOFPacket(data []byte) (bool, error) { +// parseEOFPacket returns the warning count and a boolean to indicate if there +// are more results to receive. +// +// Note: This is only valid on actual EOF packets and not on OK packets with the EOF +// type code set, i.e. should not be used if ClientDeprecateEOF is set. +func parseEOFPacket(data []byte) (warnings uint16, more bool, err error) { + // The warning count is in position 2 & 3 + warnings, _, ok := readUint16(data, 1) + // The status flag is in position 4 & 5 statusFlags, _, ok := readUint16(data, 3) if !ok { - return false, fmt.Errorf("invalid EOF packet statusFlags: %v", data) + return 0, false, fmt.Errorf("invalid EOF packet statusFlags: %v", data) } - return (statusFlags & ServerMoreResultsExists) != 0, nil + return warnings, (statusFlags & ServerMoreResultsExists) != 0, nil } func parseOKPacket(data []byte) (uint64, uint64, uint16, uint16, error) { diff --git a/go/mysql/conn_params.go b/go/mysql/conn_params.go index ae739c9b59c..5af81b1202a 100644 --- a/go/mysql/conn_params.go +++ b/go/mysql/conn_params.go @@ -38,6 +38,10 @@ type ConnParams struct { // The following is only set when the deprecated "dbname" flags are // supplied and will be removed. DeprecatedDBName string + + // The following is only set to force the client to connect without + // using CapabilityClientDeprecateEOF + DisableClientDeprecateEOF bool } // EnableSSL will set the right flag on the parameters. diff --git a/go/mysql/endtoend/client_test.go b/go/mysql/endtoend/client_test.go index e2ad2d3fb2e..0a190944952 100644 --- a/go/mysql/endtoend/client_test.go +++ b/go/mysql/endtoend/client_test.go @@ -142,23 +142,30 @@ func TestClientFoundRows(t *testing.T) { } } -func TestMultiResult(t *testing.T) { +func doTestMultiResult(t *testing.T, disableClientDeprecateEOF bool) { ctx := context.Background() + connParams.DisableClientDeprecateEOF = disableClientDeprecateEOF + conn, err := mysql.Connect(ctx, &connParams) expectNoError(t, err) defer conn.Close() + connParams.DisableClientDeprecateEOF = false + + expectFlag(t, "Negotiated ClientDeprecateEOF flag", (conn.Capabilities&mysql.CapabilityClientDeprecateEOF) != 0, !disableClientDeprecateEOF) + defer conn.Close() + qr, more, err := conn.ExecuteFetchMulti("select 1 from dual; set autocommit=1; select 1 from dual", 10, true) expectNoError(t, err) expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) expectRows(t, "ExecuteMultiFetch(multi result)", qr, 1) - qr, more, err = conn.ReadQueryResult(10, true) + qr, more, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) expectFlag(t, "ReadQueryResult(1)", more, true) expectRows(t, "ReadQueryResult(1)", qr, 0) - qr, more, err = conn.ReadQueryResult(10, true) + qr, more, _, err = conn.ReadQueryResult(10, true) expectNoError(t, err) expectFlag(t, "ReadQueryResult(2)", more, false) expectRows(t, "ReadQueryResult(2)", qr, 1) @@ -172,6 +179,63 @@ func TestMultiResult(t *testing.T) { expectNoError(t, err) expectFlag(t, "ExecuteMultiFetch(no result)", more, false) expectRows(t, "ExecuteMultiFetch(no result)", qr, 0) + + // The ClientDeprecateEOF protocol change has a subtle twist in which an EOF or OK + // packet happens to have the status flags in the same position if the affected_rows + // and last_insert_id are both one byte long: + // + // https://dev.mysql.com/doc/internals/en/packet-EOF_Packet.html + // https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html + // + // It turns out that there are no actual cases in which clients end up needing to make + // this distinction. If either affected_rows or last_insert_id are non-zero, the protocol + // sends an OK packet unilaterally which is properly parsed. If not, then regardless of the + // negotiated version, it can properly send the status flags. + // + result, err := conn.ExecuteFetch("create table a(id int, name varchar(128), primary key(id))", 0, false) + if err != nil { + t.Fatalf("create table failed: %v", err) + } + if result.RowsAffected != 0 { + t.Errorf("create table returned RowsAffected %v, was expecting 0", result.RowsAffected) + } + + for i := 0; i < 255; i++ { + result, err := conn.ExecuteFetch(fmt.Sprintf("insert into a(id, name) values(%v, 'nice name %v')", 1000+i, i), 1000, true) + if err != nil { + t.Fatalf("ExecuteFetch(%v) failed: %v", i, err) + } + if result.RowsAffected != 1 { + t.Errorf("insert into returned RowsAffected %v, was expecting 1", result.RowsAffected) + } + } + + qr, more, err = conn.ExecuteFetchMulti("update a set name = concat(name, ' updated'); select * from a; select count(*) from a", 300, true) + expectNoError(t, err) + expectFlag(t, "ExecuteMultiFetch(multi result)", more, true) + expectRows(t, "ExecuteMultiFetch(multi result)", qr, 255) + + qr, more, _, err = conn.ReadQueryResult(300, true) + expectNoError(t, err) + expectFlag(t, "ReadQueryResult(1)", more, true) + expectRows(t, "ReadQueryResult(1)", qr, 255) + + qr, more, _, err = conn.ReadQueryResult(300, true) + expectNoError(t, err) + expectFlag(t, "ReadQueryResult(2)", more, false) + expectRows(t, "ReadQueryResult(2)", qr, 1) + + result, err = conn.ExecuteFetch("drop table a", 10, true) + if err != nil { + t.Fatalf("drop table failed: %v", err) + } +} + +func TestMultiResultDeprecateEOF(t *testing.T) { + doTestMultiResult(t, false) +} +func TestMultiResultNoDeprecateEOF(t *testing.T) { + doTestMultiResult(t, true) } func expectNoError(t *testing.T, err error) { diff --git a/go/mysql/endtoend/query_test.go b/go/mysql/endtoend/query_test.go index c432bdef5fa..185a6de5627 100644 --- a/go/mysql/endtoend/query_test.go +++ b/go/mysql/endtoend/query_test.go @@ -237,3 +237,57 @@ func readRowsUsingStream(t *testing.T, conn *mysql.Conn, expectedCount int) { } conn.CloseResult() } + +func doTestWarnings(t *testing.T, disableClientDeprecateEOF bool) { + ctx := context.Background() + + connParams.DisableClientDeprecateEOF = disableClientDeprecateEOF + + conn, err := mysql.Connect(ctx, &connParams) + expectNoError(t, err) + defer conn.Close() + + connParams.DisableClientDeprecateEOF = false + + expectFlag(t, "Negotiated ClientDeprecateEOF flag", (conn.Capabilities&mysql.CapabilityClientDeprecateEOF) != 0, !disableClientDeprecateEOF) + defer conn.Close() + + result, err := conn.ExecuteFetch("create table a(id int, val int not null, primary key(id))", 0, false) + if err != nil { + t.Fatalf("create table failed: %v", err) + } + if result.RowsAffected != 0 { + t.Errorf("create table returned RowsAffected %v, was expecting 0", result.RowsAffected) + } + + // Disable strict mode + result, err = conn.ExecuteFetch("set session sql_mode=''", 0, false) + if err != nil { + t.Fatalf("disable strict mode failed: %v", err) + } + + // Try a simple insert with a null value + result, warnings, err := conn.ExecuteFetchWithWarningCount("insert into a(id) values(10)", 1000, true) + if err != nil { + t.Fatalf("insert failed: %v", err) + } + if result.RowsAffected != 1 || len(result.Rows) != 0 { + t.Errorf("unexpected result for insert: %v", result) + } + if warnings != 1 { + t.Errorf("unexpected result for warnings: %v", warnings) + } + + result, err = conn.ExecuteFetch("drop table a", 0, false) + if err != nil { + t.Fatalf("create table failed: %v", err) + } +} + +func TestWarningsDeprecateEOF(t *testing.T) { + doTestWarnings(t, false) +} + +func TestWarningsNoDeprecateEOF(t *testing.T) { + doTestWarnings(t, true) +} diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 4aaf965b02a..2272b17d7cb 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -320,6 +320,11 @@ func (db *DB) ComQuery(c *mysql.Conn, query string, callback func(*sqltypes.Resu return db.Handler.HandleQuery(c, query, callback) } +// WarningCount is part of the mysql.Handler interface. +func (db *DB) WarningCount(c *mysql.Conn) uint16 { + return 0 +} + // HandleQuery is the default implementation of the QueryHandler interface func (db *DB) HandleQuery(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { if db.AllowAll { diff --git a/go/mysql/query.go b/go/mysql/query.go index 067f01b4d96..ebaabb27a6a 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -311,22 +311,45 @@ func (c *Conn) ExecuteFetchMulti(query string, maxrows int, wantfields bool) (re return nil, false, err } - return c.ReadQueryResult(maxrows, wantfields) + res, more, _, err := c.ReadQueryResult(maxrows, wantfields) + return res, more, err +} + +// ExecuteFetchWithWarningCount is for fetching results and a warning count +// Note: In a future iteration this should be abolished and merged into the +// ExecuteFetch API. +func (c *Conn) ExecuteFetchWithWarningCount(query string, maxrows int, wantfields bool) (result *sqltypes.Result, warnings uint16, err error) { + defer func() { + if err != nil { + if sqlerr, ok := err.(*SQLError); ok { + sqlerr.Query = query + } + } + }() + + // Send the query as a COM_QUERY packet. + if err = c.WriteComQuery(query); err != nil { + return nil, 0, err + } + + res, _, warnings, err := c.ReadQueryResult(maxrows, wantfields) + return res, warnings, err } // ReadQueryResult gets the result from the last written query. -func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.Result, more bool, err error) { +func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.Result, more bool, warnings uint16, err error) { // Get the result. - affectedRows, lastInsertID, colNumber, more, err := c.readComQueryResponse() + affectedRows, lastInsertID, colNumber, more, warnings, err := c.readComQueryResponse() if err != nil { - return nil, false, err + return nil, false, 0, err } + if colNumber == 0 { // OK packet, means no results. Just use the numbers. return &sqltypes.Result{ RowsAffected: affectedRows, InsertID: lastInsertID, - }, more, nil + }, more, warnings, nil } fields := make([]querypb.Field, colNumber) @@ -341,11 +364,11 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.R if wantfields { if err := c.readColumnDefinition(result.Fields[i], i); err != nil { - return nil, false, err + return nil, false, 0, err } } else { if err := c.readColumnDefinitionType(result.Fields[i], i); err != nil { - return nil, false, err + return nil, false, 0, err } } } @@ -354,19 +377,21 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.R // EOF is only present here if it's not deprecated. data, err := c.readEphemeralPacket() if err != nil { - return nil, false, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) } if isEOFPacket(data) { + // This is what we expect. // Warnings and status flags are ignored. c.recycleReadPacket() // goto: read row loop + } else if isErrorPacket(data) { defer c.recycleReadPacket() - return nil, false, ParseErrorPacket(data) + return nil, false, 0, ParseErrorPacket(data) } else { defer c.recycleReadPacket() - return nil, false, fmt.Errorf("unexpected packet after fields: %v", data) + return nil, false, 0, fmt.Errorf("unexpected packet after fields: %v", data) } } @@ -374,7 +399,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.R for { data, err := c.ReadPacket() if err != nil { - return nil, false, err + return nil, false, 0, err } if isEOFPacket(data) { @@ -383,28 +408,41 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (result *sqltypes.R result.Fields = nil } result.RowsAffected = uint64(len(result.Rows)) - more, err := parseEOFPacket(data) - if err != nil { - return nil, false, err + + // The deprecated EOF packets change means that this is either an + // EOF packet or an OK packet with the EOF type code. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + warnings, more, err = parseEOFPacket(data) + if err != nil { + return nil, false, 0, err + } + } else { + var statusFlags uint16 + _, _, statusFlags, warnings, err = parseOKPacket(data) + if err != nil { + return nil, false, 0, err + } + more = (statusFlags & ServerMoreResultsExists) != 0 } - return result, more, nil + return result, more, warnings, nil + } else if isErrorPacket(data) { // Error packet. - return nil, false, ParseErrorPacket(data) + return nil, false, 0, ParseErrorPacket(data) } // Check we're not over the limit before we add more. if len(result.Rows) == maxrows { if err := c.drainResults(); err != nil { - return nil, false, err + return nil, false, 0, err } - return nil, false, NewSQLError(ERVitessMaxRowsExceeded, SSUnknownSQLState, "Row count exceeded %d", maxrows) + return nil, false, 0, NewSQLError(ERVitessMaxRowsExceeded, SSUnknownSQLState, "Row count exceeded %d", maxrows) } // Regular row. row, err := c.parseRow(data, result.Fields) if err != nil { - return nil, false, err + return nil, false, 0, err } result.Rows = append(result.Rows, row) } @@ -428,36 +466,35 @@ func (c *Conn) drainResults() error { } } -func (c *Conn) readComQueryResponse() (uint64, uint64, int, bool, error) { +func (c *Conn) readComQueryResponse() (affectedRows uint64, lastInsertID uint64, status int, more bool, warnings uint16, err error) { data, err := c.readEphemeralPacket() if err != nil { - return 0, 0, 0, false, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + return 0, 0, 0, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) } defer c.recycleReadPacket() if len(data) == 0 { - return 0, 0, 0, false, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "invalid empty COM_QUERY response packet") + return 0, 0, 0, false, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "invalid empty COM_QUERY response packet") } switch data[0] { case OKPacket: - affectedRows, lastInsertID, status, _, err := parseOKPacket(data) - return affectedRows, lastInsertID, 0, (status & ServerMoreResultsExists) != 0, err + affectedRows, lastInsertID, status, warnings, err := parseOKPacket(data) + return affectedRows, lastInsertID, 0, (status & ServerMoreResultsExists) != 0, warnings, err case ErrPacket: // Error - return 0, 0, 0, false, ParseErrorPacket(data) + return 0, 0, 0, false, 0, ParseErrorPacket(data) case 0xfb: // Local infile - return 0, 0, 0, false, fmt.Errorf("not implemented") + return 0, 0, 0, false, 0, fmt.Errorf("not implemented") } - n, pos, ok := readLenEncInt(data, 0) if !ok { - return 0, 0, 0, false, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "cannot get column number") + return 0, 0, 0, false, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "cannot get column number") } if pos != len(data) { - return 0, 0, 0, false, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extra data in COM_QUERY response") + return 0, 0, 0, false, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "extra data in COM_QUERY response") } - return 0, 0, int(n), false, nil + return 0, 0, int(n), false, 0, nil } // @@ -598,20 +635,20 @@ func (c *Conn) writeRows(result *sqltypes.Result) error { // writeEndResult concludes the sending of a Result. // if more is set to true, then it means there are more results afterwords -func (c *Conn) writeEndResult(more bool) error { +func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warnings uint16) error { // Send either an EOF, or an OK packet. // See doc.go. - flag := c.StatusFlags + flags := c.StatusFlags if more { - flag |= ServerMoreResultsExists + flags |= ServerMoreResultsExists } if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - if err := c.writeEOFPacket(flag, 0); err != nil { + if err := c.writeEOFPacket(flags, warnings); err != nil { return err } } else { // This will flush too. - if err := c.writeOKPacketWithEOFHeader(0, 0, flag, 0); err != nil { + if err := c.writeOKPacketWithEOFHeader(affectedRows, lastInsertID, flags, warnings); err != nil { return err } } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index e8a957ed334..c2ae6cf4176 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -300,20 +300,24 @@ func checkQuery(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes sConn.Capabilities = 0 cConn.Capabilities = 0 - checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */) + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */) + + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */) sConn.Capabilities = CapabilityClientDeprecateEOF cConn.Capabilities = CapabilityClientDeprecateEOF - checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */) - checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */) + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, true /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, false /* allRows */, false /* warnings */) + checkQueryInternal(t, query, sConn, cConn, result, false /* wantfields */, false /* allRows */, false /* warnings */) + + checkQueryInternal(t, query, sConn, cConn, result, true /* wantfields */, true /* allRows */, true /* warnings */) } -func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result, wantfields, allRows bool) { +func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result *sqltypes.Result, wantfields, allRows, warnings bool) { if sConn.Capabilities&CapabilityClientDeprecateEOF > 0 { query += " NOEOF" @@ -331,6 +335,14 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * query += " PARTIAL" } + var warningCount uint16 + if warnings { + query += " WARNINGS" + warningCount = 99 + } else { + query += " NOWARNINGS" + } + // Use a go routine to run ExecuteFetch. wg := sync.WaitGroup{} wg.Add(1) @@ -343,7 +355,7 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * // Asking for just one row max. The results that have more will fail. maxrows = 1 } - got, err := cConn.ExecuteFetch(query, maxrows, wantfields) + got, gotWarnings, err := cConn.ExecuteFetchWithWarningCount(query, maxrows, wantfields) if !allRows && len(result.Rows) > 1 { if err == nil { t.Errorf("ExecuteFetch should have failed but got: %v", got) @@ -371,6 +383,10 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * t.Fatalf("ExecuteFetch(wantfields=%v) returned:\n%v\nBut was expecting:\n%v", wantfields, got, expected) } + if gotWarnings != warningCount { + t.Errorf("ExecuteFetch(%v) expected %v warnings got %v", query, warningCount, gotWarnings) + } + // Test ExecuteStreamFetch, build a Result. expected = *result if err := cConn.ExecuteStreamFetch(query); err != nil { @@ -425,22 +441,16 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * count-- } + handler := testHandler{ + result: result, + warnings: warningCount, + } + for i := 0; i < count; i++ { - comQuery, err := sConn.ReadPacket() + err := sConn.handleNextCommand(&handler) if err != nil { - t.Fatalf("server cannot read query: %v", err) - } - if comQuery[0] != ComQuery { - t.Fatalf("server got bad packet: %v", comQuery) - } - got := sConn.parseComQuery(comQuery) - if got != query { - t.Errorf("server got query '%v' but expected '%v'", got, query) - } - if err := writeResult(sConn, result); err != nil { - t.Errorf("Error writing result to client: %v", err) + t.Fatalf("error handling command: %v", err) } - sConn.sequence = 0 } wg.Wait() @@ -456,7 +466,7 @@ func writeResult(conn *Conn, result *sqltypes.Result) error { if err := conn.writeRows(result); err != nil { return err } - return conn.writeEndResult(false) + return conn.writeEndResult(false, 0, 0, 0) } func RowString(row []sqltypes.Value) string { diff --git a/go/mysql/server.go b/go/mysql/server.go index 6c00b1656f5..37d5e2b3e5b 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -18,7 +18,6 @@ package mysql import ( "crypto/tls" - "errors" "fmt" "io" "net" @@ -85,6 +84,13 @@ type Handler interface { // the first call to callback. So the Handler should not // hang on to the byte slice. ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error + + // WarningCount is called at the end of each query to obtain + // the value to be returned to the client in the EOF packet. + // Note that this will be called either in the context of the + // ComQuery callback if the result does not contain any fields, + // or after the last ComQuery call completes. + WarningCount(c *Conn) uint16 } // Listener is the MySQL server protocol listener. @@ -232,7 +238,7 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti if l.connReadTimeout != 0 || l.connWriteTimeout != 0 { conn = netutil.NewConnWithTimeouts(conn, l.connReadTimeout, l.connWriteTimeout) } - c := newServerConn(conn, l.connReadBufferSize) + c := newServerConn(conn, l) c.ConnectionID = connectionID // Catch panics, and close the connection in any case. @@ -378,149 +384,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } for { - c.sequence = 0 - data, err := c.readEphemeralPacket() + err := c.handleNextCommand(l.handler) if err != nil { - // Don't log EOF errors. They cause too much spam. - // Note the EOF detection is not 100% - // guaranteed, in the case where the client - // connection is already closed before we call - // 'readEphemeralPacket'. This is a corner - // case though, and very unlikely to happen, - // and the only downside is we log a bit more then. - if err != io.EOF { - log.Errorf("Error reading packet from %s: %v", c, err) - } - return - } - - switch data[0] { - case ComQuit: - c.recycleReadPacket() return - case ComInitDB: - db := c.parseComInitDB(data) - c.recycleReadPacket() - c.SchemaName = db - if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Errorf("Error writing ComInitDB result to %s: %v", c, err) - return - } - case ComQuery: - // flush is called at the end of this block. - // We cannot encapsulate it with a defer inside a func because - // we have to return from this func if it fails. - c.startWriterBuffering() - - queryStart := time.Now() - query := c.parseComQuery(data) - c.recycleReadPacket() - fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - err := l.handler.ComQuery(c, query, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - - if !fieldSent { - fieldSent = true - - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) - } - if err := c.writeFields(qr); err != nil { - return err - } - } - - return c.writeRows(qr) - }) - - // If no field was sent, we expect an error. - if !fieldSent { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) - return - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) - return - } - - // Send the end packet only sendFinished is false (results were streamed). - if !sendFinished { - if err := c.writeEndResult(false); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return - } - } - } - - timings.Record(queryTimingKey, queryStart) - - if err := c.flush(); err != nil { - log.Errorf("Conn %v: Flush() failed: %v", c.ID(), err) - return - } - - case ComPing: - c.recycleReadPacket() - // Return error if listener was shut down and OK otherwise - if l.isShutdown() { - if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { - log.Errorf("Error writing ComPing error to %s: %v", c, err) - return - } - } else { - if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Errorf("Error writing ComPing result to %s: %v", c, err) - return - } - } - case ComSetOption: - if operation, ok := c.parseComSetOption(data); ok { - switch operation { - case 0: - c.Capabilities |= CapabilityClientMultiStatements - case 1: - c.Capabilities &^= CapabilityClientMultiStatements - default: - log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Errorf("Error writing error packet to client: %v", err) - return - } - } - if err := c.writeEndResult(false); err != nil { - log.Errorf("Error writeEndResult error %v ", err) - return - } - } else { - log.Errorf("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Errorf("Error writing error packet to client: %v", err) - return - } - } - default: - log.Errorf("Got unhandled packet from %s, returning error: %v", c, data) - c.recycleReadPacket() - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil { - log.Errorf("Error writing error packet to %s: %s", c, err) - return - } } } } diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index ddf4c01b943..7023df61bb0 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -65,7 +65,9 @@ var selectRowsResult = &sqltypes.Result{ type testHandler struct { lastConn *Conn + result *sqltypes.Result err error + warnings uint16 } func (th *testHandler) NewConnection(c *Conn) { @@ -76,6 +78,11 @@ func (th *testHandler) ConnectionClosed(c *Conn) { } func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { + if th.result != nil { + callback(th.result) + return nil + } + switch query { case "error": return th.err @@ -164,6 +171,10 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } +func (th *testHandler) WarningCount(c *Conn) uint16 { + return th.warnings +} + func getHostPort(t *testing.T, a net.Addr) (string, int) { // For the host name, we resolve 'localhost' into an address. // This works around a few travis issues where IPv6 is not 100% enabled. @@ -589,6 +600,23 @@ func TestServer(t *testing.T) { !strings.Contains(output, "2 rows in set") { t.Errorf("Unexpected output for 'select rows'") } + if strings.Contains(output, "warnings") { + t.Errorf("Unexpected warnings in 'select rows'") + } + + // Run a 'select rows' command with warnings + th.warnings = 13 + output, ok = runMysql(t, params, "select rows") + if !ok { + t.Fatalf("mysql failed: %v", output) + } + if !strings.Contains(output, "nice name") || + !strings.Contains(output, "nicer name") || + !strings.Contains(output, "2 rows in set") || + !strings.Contains(output, "13 warnings") { + t.Errorf("Unexpected output for 'select rows': %v", output) + } + th.warnings = 0 // If there's an error after streaming has started, // we should get a 2013 diff --git a/go/mysql/streaming_query.go b/go/mysql/streaming_query.go index f93ab193b23..c606a20f469 100644 --- a/go/mysql/streaming_query.go +++ b/go/mysql/streaming_query.go @@ -47,7 +47,7 @@ func (c *Conn) ExecuteStreamFetch(query string) (err error) { } // Get the result. - _, _, colNumber, _, err := c.readComQueryResponse() + _, _, colNumber, _, _, err := c.readComQueryResponse() if err != nil { return err } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 37628f9cb3c..d1c9ed188e4 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -158,6 +158,10 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq return callback(result) } +func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 { + return 0 +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 70281f33ca1..c873206e6ad 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -43,6 +43,10 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { + return 0 +} + func TestConnectionUnixSocket(t *testing.T) { th := &testHandler{} diff --git a/go/vt/vtqueryserver/endtoend_test.go b/go/vt/vtqueryserver/endtoend_test.go index dc4fb82e426..fce6827ccc9 100644 --- a/go/vt/vtqueryserver/endtoend_test.go +++ b/go/vt/vtqueryserver/endtoend_test.go @@ -481,7 +481,7 @@ func TestQueryDeadline(t *testing.T) { t.Errorf("Unexpected error code: %d, want %d", got, want) } - _, _, err = conn.ReadQueryResult(1000, false) + _, _, _, err = conn.ReadQueryResult(1000, false) if err != nil { t.Errorf("unexpected error %v", err) } diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index ab0327b6273..804cd207f8b 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -132,6 +132,10 @@ func (mh *proxyHandler) ComQuery(c *mysql.Conn, query string, callback func(*sql return callback(result) } +func (mh *proxyHandler) WarningCount(c *mysql.Conn) uint16 { + return 0 +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener diff --git a/go/vt/vtqueryserver/plugin_mysql_server_test.go b/go/vt/vtqueryserver/plugin_mysql_server_test.go index 40eab721d26..ca0f6b806cd 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server_test.go +++ b/go/vt/vtqueryserver/plugin_mysql_server_test.go @@ -43,6 +43,10 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { + return 0 +} + func TestConnectionUnixSocket(t *testing.T) { th := &testHandler{}