From 4f0210a1b466847dbc095f42b06028b3b68662f7 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 1 Oct 2020 10:33:13 +0200 Subject: [PATCH 1/5] Fix error around breaking of multistatements When a multistatement query is received, any errors should abort the execution of the remaining queries. The `execQuery` and `handleNextCommand` were returning an error, but not actually using the error value - just checking if it was nil or not. We need to be able to know on the outside of `execQuery` if an error occured and if it was an error we need to close the connection for or if it was a simple execution error. Signed-off-by: Andres Taylor --- go/mysql/conn.go | 313 ++++++++++++++++++++--------------------- go/mysql/conn_test.go | 70 +++++++++ go/mysql/query_test.go | 6 +- go/mysql/server.go | 4 +- 4 files changed, 231 insertions(+), 162 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 81d7c64cec3..97f467733e8 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -758,7 +758,7 @@ func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { // handleNextCommand is called in the server loop to process // incoming packets. -func (c *Conn) handleNextCommand(handler Handler) error { +func (c *Conn) handleNextCommand(handler Handler) bool { c.sequence = 0 data, err := c.readEphemeralPacket() if err != nil { @@ -766,78 +766,72 @@ func (c *Conn) handleNextCommand(handler Handler) error { if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { log.Errorf("Error reading packet from %s: %v", c, err) } - return err + return false } switch data[0] { case ComQuit: c.recycleReadPacket() - return errors.New("ComQuit") + return false case ComInitDB: db := c.parseComInitDB(data) c.recycleReadPacket() - if err := c.execQuery("use "+sqlescape.EscapeID(db), handler, false); err != nil { - return err - } + res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) + return res == execSuccess // TODO: we shouldn't drop the connection if the user is asking for the wrong db case ComQuery: - err := func() error { - c.startWriterBuffering() - defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } - }() + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + } + }() - queryStart := time.Now() - query := c.parseComQuery(data) - c.recycleReadPacket() - - var queries []string - if c.Capabilities&CapabilityClientMultiStatements != 0 { - queries, err = sqlparser.SplitStatementToPieces(query) - if err != nil { - log.Errorf("Conn %v: Error splitting query: %v", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return werr - } + queryStart := time.Now() + query := c.parseComQuery(data) + c.recycleReadPacket() + + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return false } - } else { - queries = []string{query} } - for index, sql := range queries { - more := false - if index != len(queries)-1 { - more = true - } - if err := c.execQuery(sql, handler, more); err != nil { - return err - } + } else { + queries = []string{query} + } + for index, sql := range queries { + more := false + if index != len(queries)-1 { + more = true + } + res := c.execQuery(sql, handler, more) + if res != execSuccess { + return res != connErr } - - timings.Record(queryTimingKey, queryStart) - - return nil - }() - if err != nil { - return err } + timings.Record(queryTimingKey, queryStart) + case ComPing: c.recycleReadPacket() // Return error if listener was shut down and OK otherwise if c.listener.isShutdown() { if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { log.Errorf("Error writing ComPing error to %s: %v", c, err) - return err + return false } } else { if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { log.Errorf("Error writing ComPing result to %s: %v", c, err) - return err + return false } } + case ComSetOption: operation, ok := c.parseComSetOption(data) c.recycleReadPacket() @@ -851,20 +845,21 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Errorf("Error writing error packet to client: %v", err) - return err + return false } } if err := c.writeEndResult(false, 0, 0, 0); err != nil { log.Errorf("Error writeEndResult error %v ", err) - return err + return false } } else { log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Errorf("Error writing error packet to client: %v", err) - return err + return false } } + case ComPrepare: query := c.parseComPrepare(data) c.recycleReadPacket() @@ -877,7 +872,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return werr + return false } } } else { @@ -885,7 +880,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { } if len(queries) != 1 { - return fmt.Errorf("can not prepare multiple statements") + return false // TODO: do we really want to close the connection because of this? } // Popoulate PrepareData @@ -901,7 +896,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Conn %v: Error writing prepared statement error: %v", c, werr) - return werr + return false } } @@ -936,120 +931,116 @@ func (c *Conn) handleNextCommand(handler Handler) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return werr + return false } - return nil + return true } if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { - return err + log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) + return false } case ComStmtExecute: - err := func() error { - c.startWriterBuffering() + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + } + }() + queryStart := time.Now() + stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + c.recycleReadPacket() + + if stmtID != uint32(0) { defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } + // Allocate a new bindvar map every time since VTGate.Execute() mutates it. + prepare := c.PrepareData[stmtID] + prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) }() - queryStart := time.Now() - stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) - c.recycleReadPacket() - - if stmtID != uint32(0) { - defer func() { - // Allocate a new bindvar map every time since VTGate.Execute() mutates it. - prepare := c.PrepareData[stmtID] - prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) - }() - } + } - if err != nil { - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return werr - } - return nil + if err != nil { + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false } + return true + } - fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - prepare := c.PrepareData[stmtID] - err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - - if !fieldSent { - fieldSent = true - - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) - } - if err := c.writeFields(qr); err != nil { - return err - } - } - - return c.writeBinaryRows(qr) - }) + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + prepare := c.PrepareData[stmtID] + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } - // If no field was sent, we expect an error. if !fieldSent { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) - return werr + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) + if err := c.writeFields(qr); err != nil { return err } + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return err - } - } + return c.writeBinaryRows(qr) + }) + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return false + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false } - timings.Record(queryTimingKey, queryStart) - return nil - }() - if err != nil { - return err + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return false + } + } } + + timings.Record(queryTimingKey, queryStart) + case ComStmtSendLongData: stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() if !ok { err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) log.Error(err.Error()) - return err + return false // TODO: really break here? } prepare, ok := c.PrepareData[stmtID] if !ok { err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) log.Error(err.Error()) - return err + return false // TODO: really break here? } if prepare.BindVars == nil || @@ -1057,7 +1048,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { paramID >= prepare.ParamsCount { err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) log.Error(err.Error()) - return err + return false // TODO: really break here? } chunk := make([]byte, len(chunkData)) @@ -1082,7 +1073,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { log.Error("Error writing error packet to client: %v", err) - return err + return false } } @@ -1091,7 +1082,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil { log.Error("Error writing error packet to client: %v", err) - return err + return false } } @@ -1103,7 +1094,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) - return err + return false } case ComResetConnection: @@ -1122,14 +1113,22 @@ func (c *Conn) handleNextCommand(handler Handler) error { c.recycleReadPacket() if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil { log.Errorf("Error writing error packet to %s: %s", c, err) - return err + return false } } - return nil + return true } -func (c *Conn) execQuery(query string, handler Handler, more bool) error { +type execResult byte + +const ( + execSuccess execResult = iota + execErr + connErr +) + +func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { fieldSent := false // sendFinished is set if the response should just be an OK packet. sendFinished := false @@ -1175,28 +1174,28 @@ func (c *Conn) execQuery(query string, handler Handler, more bool) error { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Errorf("Error writing query error to %s: %v", c, werr) - return werr - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) - return err + return connErr } + return execErr + } + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return connErr + } - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return err - } + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(more, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return connErr } } - return nil + return execSuccess } // diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index d8ab1a8526a..ba88fbf2810 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -19,12 +19,18 @@ package mysql import ( "bytes" crypto_rand "crypto/rand" + "fmt" "math/rand" "net" "reflect" + "runtime/debug" "sync" "testing" "time" + + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { @@ -288,3 +294,67 @@ func TestEOFOrLengthEncodedIntFuzz(t *testing.T) { } } } + +func TestMultiStatementStopsOnError(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("select 1;select 2") + require.NoError(t, err) + + // this handler will return an error on the first run, and fail the test if it's run more times + handler := &singleRun{t: t, err: fmt.Errorf("execution failed")} + res := sConn.handleNextCommand(handler) + require.True(t, res, res, "we should not break the connection because of execution errors") + + data, err := cConn.ReadPacket() + require.NoError(t, err) + require.NotEmpty(t, data) + require.EqualValues(t, data[0], ErrPacket) // we should see the error here +} + +type singleRun struct { + hasRun bool + t *testing.T + err error +} + +func (h *singleRun) NewConnection(*Conn) { + panic("implement me") +} + +func (h *singleRun) ConnectionClosed(*Conn) { + panic("implement me") +} + +func (h *singleRun) ComQuery(*Conn, string, func(*sqltypes.Result) error) error { + if h.hasRun { + debug.PrintStack() + h.t.Fatal("don't do this!") + } + h.hasRun = true + return h.err +} + +func (h *singleRun) ComPrepare(*Conn, string, map[string]*querypb.BindVariable) ([]*querypb.Field, error) { + panic("implement me") +} + +func (h *singleRun) ComStmtExecute(*Conn, *PrepareData, func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (h *singleRun) WarningCount(*Conn) uint16 { + return 0 +} + +func (h *singleRun) ComResetConnection(*Conn) { + panic("implement me") +} + +var _ Handler = (*singleRun)(nil) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 8273759c38d..105271dc6d2 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -628,9 +628,9 @@ func checkQueryInternal(t *testing.T, query string, sConn, cConn *Conn, result * } for i := 0; i < count; i++ { - err := sConn.handleNextCommand(&handler) - if err != nil { - t.Fatalf("error handling command: %v", err) + kontinue := sConn.handleNextCommand(&handler) + if !kontinue { + t.Fatalf("error handling command: %d", i) } } diff --git a/go/mysql/server.go b/go/mysql/server.go index bc3ecc8dea7..c44f90c3c71 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -470,8 +470,8 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } for { - err := c.handleNextCommand(l.handler) - if err != nil { + kontinue := c.handleNextCommand(l.handler) + if !kontinue { return } } From f0221814144b84f4c0b0b33b148c142571a43bd3 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Thu, 1 Oct 2020 12:28:17 +0200 Subject: [PATCH 2/5] debugging failing test Signed-off-by: Andres Taylor --- go/mysql/server_test.go | 72 +++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 42 deletions(-) diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index d84f6db8118..fb97c03c3eb 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -644,10 +644,8 @@ func TestServer(t *testing.T) { // Run a 'panic' command, other side should panic, recover and // close the connection. - output, ok = runMysql(t, params, "panic") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "panic") + require.Error(t, err) if !strings.Contains(output, "ERROR 2013 (HY000)") || !strings.Contains(output, "Lost connection to MySQL server during query") { t.Errorf("Unexpected output for 'panic'") @@ -666,10 +664,8 @@ func TestServer(t *testing.T) { } // Run a 'select rows' command with results. - output, ok = runMysql(t, params, "select rows") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.NoError(t, err) if !strings.Contains(output, "nice name") || !strings.Contains(output, "nicer name") || !strings.Contains(output, "2 rows in set") { @@ -681,10 +677,8 @@ func TestServer(t *testing.T) { // Run a 'select rows' command with warnings th.SetWarnings(13) - output, ok = runMysql(t, params, "select rows") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.NoError(t, err) if !strings.Contains(output, "nice name") || !strings.Contains(output, "nicer name") || !strings.Contains(output, "2 rows in set") || @@ -696,39 +690,31 @@ func TestServer(t *testing.T) { // If there's an error after streaming has started, // we should get a 2013 th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced error after send")) - output, ok = runMysql(t, params, "error after send") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "error after send") + require.Error(t, err) if !strings.Contains(output, "ERROR 2013 (HY000)") || !strings.Contains(output, "Lost connection to MySQL server during query") { t.Errorf("Unexpected output for 'panic'") } // Run an 'insert' command, no rows, but rows affected. - output, ok = runMysql(t, params, "insert") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "insert") + require.NoError(t, err) if !strings.Contains(output, "Query OK, 123 rows affected") { t.Errorf("Unexpected output for 'insert'") } // Run a 'schema echo' command, to make sure db name is right. params.DbName = "XXXfancyXXX" - output, ok = runMysql(t, params, "schema echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "schema echo") + require.NoError(t, err) if !strings.Contains(output, params.DbName) { t.Errorf("Unexpected output for 'schema echo'") } // Sanity check: make sure this didn't go through SSL - output, ok = runMysql(t, params, "ssl echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "ssl echo") + require.NoError(t, err) if !strings.Contains(output, "ssl_flag") || !strings.Contains(output, "OFF") || !strings.Contains(output, "1 row in set") { @@ -736,10 +722,8 @@ func TestServer(t *testing.T) { } // UserData check: checks the server user data is correct. - output, ok = runMysql(t, params, "userData echo") - if !ok { - t.Fatalf("mysql failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "userData echo") + require.NoError(t, err) if !strings.Contains(output, "user1") || !strings.Contains(output, "user_data") || !strings.Contains(output, "userData1") { @@ -748,10 +732,8 @@ func TestServer(t *testing.T) { // Permissions check: check a bad password is rejected. params.Pass = "bad" - output, ok = runMysql(t, params, "select rows") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.Error(t, err) if !strings.Contains(output, "1045") || !strings.Contains(output, "28000") || !strings.Contains(output, "Access denied") { @@ -761,10 +743,8 @@ func TestServer(t *testing.T) { // Permissions check: check an unknown user is rejected. params.Pass = "password1" params.Uname = "user2" - output, ok = runMysql(t, params, "select rows") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } + output, err = runMysqlWithErr(t, params, "select rows") + require.Error(t, err) if !strings.Contains(output, "1045") || !strings.Contains(output, "28000") || !strings.Contains(output, "Access denied") { @@ -1219,6 +1199,14 @@ const enableCleartextPluginPrefix = "enable-cleartext-plugin: " // runMysql forks a mysql command line process connecting to the provided server. func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) { + output, err := runMysqlWithErr(t, params, command) + if err != nil { + return output, false + } + return output, true + +} +func runMysqlWithErr(t *testing.T, params *ConnParams, command string) (string, error) { dir, err := vtenv.VtMysqlRoot() if err != nil { t.Fatalf("vtenv.VtMysqlRoot failed: %v", err) @@ -1277,9 +1265,9 @@ func runMysql(t *testing.T, params *ConnParams, command string) (string, bool) { out, err := cmd.CombinedOutput() output := string(out) if err != nil { - return output, false + return output, err } - return output, true + return output, nil } // binaryPath does a limited path lookup for a command, From 7a530bb5e5c865d56b1a6911ca31e041e0ce3229 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 2 Oct 2020 11:45:21 +0200 Subject: [PATCH 3/5] clean up tests to make them less fragile Signed-off-by: Andres Taylor --- go/mysql/server_test.go | 156 +++++++++++++++++++++------------------- 1 file changed, 83 insertions(+), 73 deletions(-) diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index fb97c03c3eb..9b493ada6b0 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -580,9 +580,7 @@ func TestServer(t *testing.T) { }} defer authServer.close() l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) - if err != nil { - t.Fatalf("NewListener failed: %v", err) - } + require.NoError(t, err) l.SlowConnectWarnThreshold.Set(time.Duration(time.Nanosecond * 1)) defer l.Close() go l.Accept() @@ -597,83 +595,16 @@ func TestServer(t *testing.T) { Pass: "password1", } - initialTimingCounts := timings.Counts() - initialConnAccept := connAccept.Get() - initialConnSlow := connSlow.Get() - initialconnRefuse := connRefuse.Get() - - // Run an 'error' command. - th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced query error")) - output, ok := runMysql(t, params, "error") - if ok { - t.Fatalf("mysql should have failed: %v", output) - } - if !strings.Contains(output, "ERROR 1047 (08S01)") || - !strings.Contains(output, "forced query error") { - t.Errorf("Unexpected output for 'error': %v", output) - } - if connCount.Get() != 0 { - t.Errorf("Expected ConnCount=0, got %d", connCount.Get()) - } - if connAccept.Get()-initialConnAccept != 1 { - t.Errorf("Expected ConnAccept delta=1, got %d", connAccept.Get()-initialConnAccept) - } - if connSlow.Get()-initialConnSlow != 1 { - t.Errorf("Expected ConnSlow delta=1, got %d", connSlow.Get()-initialConnSlow) - } - if connRefuse.Get()-initialconnRefuse != 0 { - t.Errorf("Expected connRefuse delta=0, got %d", connRefuse.Get()-initialconnRefuse) - } - - expectedTimingDeltas := map[string]int64{ - "All": 2, - connectTimingKey: 1, - queryTimingKey: 1, - } - gotTimingCounts := timings.Counts() - for key, got := range gotTimingCounts { - expected := expectedTimingDeltas[key] - delta := got - initialTimingCounts[key] - if delta < expected { - t.Errorf("Expected Timing count delta %s should be >= %d, got %d", key, expected, delta) - } - } - - // Set the slow connect threshold to something high that we don't expect to trigger - l.SlowConnectWarnThreshold.Set(time.Duration(time.Second * 1)) - - // Run a 'panic' command, other side should panic, recover and - // close the connection. - output, err = runMysqlWithErr(t, params, "panic") - require.Error(t, err) - if !strings.Contains(output, "ERROR 2013 (HY000)") || - !strings.Contains(output, "Lost connection to MySQL server during query") { - t.Errorf("Unexpected output for 'panic'") - } - if connCount.Get() != 0 { - t.Errorf("Expected ConnCount=0, got %d", connCount.Get()) - } - if connAccept.Get()-initialConnAccept != 2 { - t.Errorf("Expected ConnAccept delta=2, got %d", connAccept.Get()-initialConnAccept) - } - if connSlow.Get()-initialConnSlow != 1 { - t.Errorf("Expected ConnSlow delta=1, got %d", connSlow.Get()-initialConnSlow) - } - if connRefuse.Get()-initialconnRefuse != 0 { - t.Errorf("Expected connRefuse delta=0, got %d", connRefuse.Get()-initialconnRefuse) - } - // Run a 'select rows' command with results. - output, err = runMysqlWithErr(t, params, "select rows") + output, err := runMysqlWithErr(t, params, "select rows") require.NoError(t, err) + if !strings.Contains(output, "nice name") || !strings.Contains(output, "nicer name") || !strings.Contains(output, "2 rows in set") { t.Errorf("Unexpected output for 'select rows'") } - if strings.Contains(output, "warnings") { - t.Errorf("Unexpected warnings in 'select rows'") - } + assert.NotContains(t, output, "warnings") // Run a 'select rows' command with warnings th.SetWarnings(13) @@ -756,6 +687,85 @@ func TestServer(t *testing.T) { // time.Sleep(60 * time.Minute) } +func TestServerStats(t *testing.T) { + th := &testHandler{} + + authServer := NewAuthServerStatic("", "", 0) + authServer.entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + UserData: "userData1", + }} + defer authServer.close() + l, err := NewListener("tcp", ":0", authServer, th, 0, 0, false) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + l.SlowConnectWarnThreshold.Set(time.Duration(time.Nanosecond * 1)) + defer l.Close() + go l.Accept() + + host, port := getHostPort(t, l.Addr()) + + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + } + + timings.Reset() + connAccept.Reset() + connCount.Reset() + connSlow.Reset() + connRefuse.Reset() + + // Run an 'error' command. + th.SetErr(NewSQLError(ERUnknownComError, SSUnknownComError, "forced query error")) + output, ok := runMysql(t, params, "error") + if ok { + t.Fatalf("mysql should have failed: %v", output) + } + if !strings.Contains(output, "ERROR 1047 (08S01)") || + !strings.Contains(output, "forced query error") { + t.Errorf("Unexpected output for 'error': %v", output) + } + assert.EqualValues(t, 0, connCount.Get(), "connCount") + assert.EqualValues(t, 1, connAccept.Get(), "connAccept") + assert.EqualValues(t, 1, connSlow.Get(), "connSlow") + assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") + + expectedTimingDeltas := map[string]int64{ + "All": 2, + connectTimingKey: 1, + queryTimingKey: 1, + } + gotTimingCounts := timings.Counts() + for key, got := range gotTimingCounts { + expected := expectedTimingDeltas[key] + if got < expected { + t.Errorf("Expected Timing count delta %s should be >= %d, got %d", key, expected, got) + } + } + + // Set the slow connect threshold to something high that we don't expect to trigger + l.SlowConnectWarnThreshold.Set(time.Duration(time.Second * 1)) + + // Run a 'panic' command, other side should panic, recover and + // close the connection. + output, err = runMysqlWithErr(t, params, "panic") + require.Error(t, err) + if !strings.Contains(output, "ERROR 2013 (HY000)") || + !strings.Contains(output, "Lost connection to MySQL server during query") { + t.Errorf("Unexpected output for 'panic'") + } + + assert.EqualValues(t, 0, connCount.Get(), "connCount") + assert.EqualValues(t, 2, connAccept.Get(), "connAccept") + assert.EqualValues(t, 1, connSlow.Get(), "connSlow") + assert.EqualValues(t, 0, connRefuse.Get(), "connRefuse") +} + // TestClearTextServer creates a Server that needs clear text // passwords from the client. func TestClearTextServer(t *testing.T) { From 5cb6e5f0e02251d5923284ca854995c649140b2c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 2 Oct 2020 16:11:08 +0200 Subject: [PATCH 4/5] Fewer situations where vtgate drops connections Both when preparing and when executing prepared statements, Vitess was dropping the connection on most errors. This change makes it so Vitess instead returns an error packet but keeps the connection open. Signed-off-by: Andres Taylor --- go/mysql/conn.go | 45 +++++++++++++++++++++++++++++-------------- go/mysql/conn_test.go | 22 +++++++++++++++++++++ 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 97f467733e8..39fe94210d2 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -777,7 +777,7 @@ func (c *Conn) handleNextCommand(handler Handler) bool { db := c.parseComInitDB(data) c.recycleReadPacket() res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) - return res == execSuccess // TODO: we shouldn't drop the connection if the user is asking for the wrong db + return res != connErr case ComQuery: c.startWriterBuffering() defer func() { @@ -875,14 +875,19 @@ func (c *Conn) handleNextCommand(handler Handler) bool { return false } } + if len(queries) != 1 { + log.Errorf("Conn %v: can not prepare multiple statements", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return false + } + return true + } } else { queries = []string{query} } - if len(queries) != 1 { - return false // TODO: do we really want to close the connection because of this? - } - // Popoulate PrepareData c.StatementID++ prepare := &PrepareData{ @@ -1031,24 +1036,36 @@ func (c *Conn) handleNextCommand(handler Handler) bool { stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() if !ok { - err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) - log.Error(err.Error()) - return false // TODO: really break here? + err = fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } prepare, ok := c.PrepareData[stmtID] if !ok { - err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) - log.Error(err.Error()) - return false // TODO: really break here? + err = fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } if prepare.BindVars == nil || prepare.ParamsCount == uint16(0) || paramID >= prepare.ParamsCount { - err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) - log.Error(err.Error()) - return false // TODO: really break here? + err = fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return false + } + return true } chunk := make([]byte, len(chunkData)) diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index ba88fbf2810..aa5bd2eaf45 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -318,6 +318,28 @@ func TestMultiStatementStopsOnError(t *testing.T) { require.EqualValues(t, data[0], ErrPacket) // we should see the error here } +func TestInitDbAgainstWrongDbDoesNotDropConnection(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.writeComInitDB("database") + require.NoError(t, err) + + handler := &singleRun{t: t, err: fmt.Errorf("execution failed")} + res := sConn.handleNextCommand(handler) + require.True(t, res, "we should not break the connection because of execution errors") + + data, err := cConn.ReadPacket() + require.NoError(t, err) + require.NotEmpty(t, data) + require.EqualValues(t, data[0], ErrPacket) // we should see the error here +} + type singleRun struct { hasRun bool t *testing.T From 11a574b3e0aee417b4e22a66676d5ac3b9e48f1c Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 2 Oct 2020 16:47:41 +0200 Subject: [PATCH 5/5] style nit Signed-off-by: Andres Taylor --- go/mysql/conn.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 39fe94210d2..7a220168df4 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -190,6 +190,14 @@ type PrepareData struct { BindVars map[string]*querypb.BindVariable } +type execResult byte + +const ( + execSuccess execResult = iota + execErr + connErr +) + // bufPool is used to allocate and free buffers in an efficient way. var bufPool = bucketpool.New(connBufferSize, MaxPacketSize) @@ -1137,14 +1145,6 @@ func (c *Conn) handleNextCommand(handler Handler) bool { return true } -type execResult byte - -const ( - execSuccess execResult = iota - execErr - connErr -) - func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { fieldSent := false // sendFinished is set if the response should just be an OK packet.