diff --git a/go/mysql/conn_flaky_test.go b/go/mysql/conn_flaky_test.go index 01ce8c0839b..17d98c47620 100644 --- a/go/mysql/conn_flaky_test.go +++ b/go/mysql/conn_flaky_test.go @@ -826,6 +826,15 @@ func (t testRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) if strings.Contains(query, "panic") { panic("test panic attack!") } + if strings.Contains(query, "close before rows read") { + c.writeFields(selectRowsResult) + // We want to close the connection after the fields are written + // and read on the client. So we sleep for 100 milliseconds + time.Sleep(100 * time.Millisecond) + c.Close() + return nil + } + if strings.Contains(query, "twice") { callback(selectRowsResult) } diff --git a/go/mysql/query.go b/go/mysql/query.go index 7271d4462a1..681660ef66c 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -416,7 +416,7 @@ func (c *Conn) ReadQueryResult(maxrows int, wantfields bool) (*sqltypes.Result, for { data, err := c.readEphemeralPacket() if err != nil { - return nil, false, 0, err + return nil, false, 0, NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) } // TODO: harshit - the EOF packet is deprecated as of MySQL 5.7.5. diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 0a3689b719c..b2809a428bc 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -375,6 +375,32 @@ func TestComStmtClose(t *testing.T) { } } +// This test has been added to verify that IO errors in a connection lead to SQL Server lost errors +// So that we end up closing the connection higher up the stack and not reusing it. +// This test was added in response to a panic that was run into. +func TestSQLErrorOnServerClose(t *testing.T) { + // Create socket pair for the server and client + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("close before rows read") + require.NoError(t, err) + + handler := &testRun{t: t} + _ = sConn.handleNextCommand(handler) + + // From the server we will receive a field packet which the client will read + // At that point, if the server crashes and closes the connection. + // We should be getting a Connection lost error. + _, _, _, err = cConn.ReadQueryResult(100, true) + require.Error(t, err) + require.True(t, IsConnLostDuringQuery(err), err.Error()) +} + func TestQueries(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() {