diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 42b8215c86c..b8b4507be38 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -1038,6 +1038,18 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) return err } + + case ComResetConnection: + // Clean up and reset the connection + c.recycleReadPacket() + handler.ComResetConnection(c) + // Reset prepared statements + c.PrepareData = make(map[uint32]*PrepareData) + err = c.writeOKPacket(0, 0, 0, 0) + if err != nil { + c.writeErrorPacketFromError(err) + } + default: log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data) c.recycleReadPacket() diff --git a/go/mysql/constants.go b/go/mysql/constants.go index b1d4491f637..eaa2784dfb3 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -174,6 +174,9 @@ const ( // ComSetOption is COM_SET_OPTION ComSetOption = 0x1b + // ComResetConnection is COM_RESET_CONNECTION + ComResetConnection = 0x1f + // ComBinlogDumpGTID is COM_BINLOG_DUMP_GTID. ComBinlogDumpGTID = 0x1e diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 2d8f079dd02..eb0fc94d5b8 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -440,6 +440,11 @@ func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback return nil } +// ComResetConnection is part of the mysql.Handler interface. +func (db *DB) ComResetConnection(c *mysql.Conn) { + +} + // // Methods to add expected queries and results. // diff --git a/go/mysql/server.go b/go/mysql/server.go index 5de219ff476..9b7dee2de18 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -109,6 +109,8 @@ type Handler interface { // ComQuery callback if the result does not contain any fields, // or after the last ComQuery call completes. WarningCount(c *Conn) uint16 + + ComResetConnection(c *Conn) } // Listener is the MySQL server protocol listener. diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 49456128c9f..c2ddb9cf608 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -180,6 +180,10 @@ func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback fu return nil } +func (th *testHandler) ComResetConnection(c *Conn) { + +} + func (th *testHandler) WarningCount(c *Conn) uint16 { return th.warnings } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index ea5e16abdd5..eb0a5a98fdc 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -84,6 +84,20 @@ func newVtgateHandler(vtg *VTGate) *vtgateHandler { func (vh *vtgateHandler) NewConnection(c *mysql.Conn) { } +func (vh *vtgateHandler) ComResetConnection(c *mysql.Conn) { + ctx := context.Background() + session, _ := c.ClientData.(*vtgatepb.Session) + if session != nil { + if session.InTransaction { + defer atomic.AddInt32(&busyConnections, -1) + } + _, _, err := vh.vtg.Execute(ctx, session, "rollback", make(map[string]*querypb.BindVariable)) + if err != nil { + log.Errorf("Error happened in transaction rollback: %v", err) + } + } +} + func (vh *vtgateHandler) ConnectionClosed(c *mysql.Conn) { // Rollback if there is an ongoing transaction. Ignore error. var ctx context.Context diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 1a4ee4e8a6c..16f1b616f6f 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -51,6 +51,10 @@ func (th *testHandler) ComPrepare(c *mysql.Conn, q string) ([]*querypb.Field, er return nil, nil } +func (th *testHandler) ComResetConnection(c *mysql.Conn) { + +} + func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { return nil } diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index a2593b95af1..26d393108b0 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -144,6 +144,10 @@ func (mh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData return nil } +func (mh *proxyHandler) ComResetConnection(c *mysql.Conn) { + +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener diff --git a/go/vt/vtqueryserver/plugin_mysql_server_test.go b/go/vt/vtqueryserver/plugin_mysql_server_test.go index 6c538b80a37..7bf13959093 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server_test.go +++ b/go/vt/vtqueryserver/plugin_mysql_server_test.go @@ -52,6 +52,10 @@ func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, return nil } +func (th *testHandler) ComResetConnection(c *mysql.Conn) { + +} + func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py index e4b7ca89f42..63f05042732 100755 --- a/test/prepared_statement_test.py +++ b/test/prepared_statement_test.py @@ -213,7 +213,8 @@ def test_prepared_statements(self): utils.VtGate(mysql_server=True).start( extra_args=['-mysql_auth_server_impl', 'static', '-mysql_server_query_timeout', '1s', - '-mysql_auth_server_static_file', mysql_auth_server_static]) + '-mysql_auth_server_static_file', mysql_auth_server_static, + "-mysql_server_version", '8.0.16-7']) # We use gethostbyname('localhost') so we don't presume # of the IP format (travis is only IP v4, really). params = dict(host=socket.gethostbyname('localhost'), @@ -306,6 +307,15 @@ def test_prepared_statements(self): if res[0] != 1: self.fail("Delete failed") cursor.close() - + + # Reseting the connection + conn.cmd_reset_connection() + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('select * from vt_prepare_stmt_test where id = %s', (1,)) + result = cursor.fetchone() + # Should fail since we cleared PreparedData inside the connection + with self.assertRaises(TypeError): + empty_val = result[-2] + if __name__ == '__main__': utils.main()