diff --git a/go.mod b/go.mod index f7cac4419c..f6633d4833 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15 github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 - github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813 + github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c github.com/go-kit/kit v0.10.0 github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d github.com/gocraft/dbr/v2 v2.7.2 diff --git a/go.sum b/go.sum index b24c5b1100..63360ebc60 100644 --- a/go.sum +++ b/go.sum @@ -60,6 +60,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9X github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY= github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813 h1:tGwsoLAMFQ+7FDEyIWOIJ1Vc/nptbFi0Fh7SQahB8ro= github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= +github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c h1:Zt23BHsxvPHGfpHV9k/FcsHqWZjfybyQQux2OLpRni8= +github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= diff --git a/server/extension.go b/server/extension.go index c71ff7e6c1..407d9bac0e 100644 --- a/server/extension.go +++ b/server/extension.go @@ -167,8 +167,8 @@ func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 { return ih.h.WarningCount(c) } -func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) { - ih.h.ComResetConnection(c) +func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error { + return ih.h.ComResetConnection(c) } func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) { diff --git a/server/golden/proxy.go b/server/golden/proxy.go index 9b45c40e8a..1b6728dc2a 100644 --- a/server/golden/proxy.go +++ b/server/golden/proxy.go @@ -149,8 +149,8 @@ func (h MySqlProxy) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, ca } // ComResetConnection implements mysql.Handler. -func (h MySqlProxy) ComResetConnection(c *mysql.Conn) { - return +func (h MySqlProxy) ComResetConnection(_ *mysql.Conn) error { + return nil } // ConnectionClosed implements mysql.Handler. diff --git a/server/golden/validator.go b/server/golden/validator.go index 506cff8c74..2c675c4c2e 100644 --- a/server/golden/validator.go +++ b/server/golden/validator.go @@ -81,8 +81,8 @@ func (v Validator) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, cal return fmt.Errorf("ComStmtExecute unsupported") } -func (v Validator) ComResetConnection(c *mysql.Conn) { - return +func (v Validator) ComResetConnection(_ *mysql.Conn) error { + return nil } // ConnectionClosed reports that a connection has been closed. diff --git a/server/handler.go b/server/handler.go index cab3df3ef2..a74a121814 100644 --- a/server/handler.go +++ b/server/handler.go @@ -15,6 +15,7 @@ package server import ( + "context" "encoding/base64" "fmt" "io" @@ -197,10 +198,32 @@ func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, call return err } -func (h *Handler) ComResetConnection(c *mysql.Conn) { +// ComResetConnection implements the mysql.Handler interface. +// +// This command resets the connection's session, clearing out any cached prepared statements, locks, user and +// session variables. The currently selected database is preserved. +// +// The COM_RESET command can be sent manually through the mysql client by issuing the "resetconnection" (or "\x") +// client command. +func (h *Handler) ComResetConnection(c *mysql.Conn) error { logrus.WithField("connectionId", c.ConnectionID).Debug("COM_RESET_CONNECTION command received") - // TODO: handle reset logic + // Grab the currently selected database name + s := h.sm.session(c) + db := s.GetCurrentDatabase() + + // Dispose of the connection's current session + h.maybeReleaseAllLocks(c) + h.e.CloseSession(c.ConnectionID) + + // Create a new session and set the current database + err := h.sm.NewSession(context.Background(), c) + if err != nil { + return err + } + s = h.sm.session(c) + s.SetCurrentDatabase(db) + return nil } func (h *Handler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) { @@ -222,6 +245,14 @@ func (h *Handler) ConnectionClosed(c *mysql.Conn) { defer h.sm.RemoveConn(c) defer h.e.CloseSession(c.ConnectionID) + h.maybeReleaseAllLocks(c) + + logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed") +} + +// maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection. If the attempt fails, +// an error is logged but not returned. +func (h *Handler) maybeReleaseAllLocks(c *mysql.Conn) { if ctx, err := h.sm.NewContextWithQuery(c, ""); err != nil { logrus.Errorf("unable to release all locks on session close: %s", err) logrus.Errorf("unable to unlock tables on session close: %s", err) @@ -234,8 +265,6 @@ func (h *Handler) ConnectionClosed(c *mysql.Conn) { logrus.Errorf("unable to unlock tables on session close: %s", err) } } - - logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed") } func (h *Handler) ComMultiQuery( diff --git a/server/handler_test.go b/server/handler_test.go index a09ad44258..3d97ec66dc 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -212,6 +212,69 @@ func TestHandlerOutput(t *testing.T) { }) } +// TestHandlerComReset asserts that the Handler.ComResetConnection method correctly clears all session +// state (e.g. table locks, prepared statements, user variables, session variables), and keeps the current +// database selected. +func TestHandlerComResetConnection(t *testing.T) { + e, pro := setupMemDB(require.New(t)) + dummyConn := newConn(1) + dbFunc := pro.Database + + handler := &Handler{ + e: e, + sm: NewSessionManager( + testSessionBuilder(pro), + sql.NoopTracer, + dbFunc, + sql.NewMemoryManager(nil), + sqle.NewProcessList(), + "foo", + ), + } + handler.NewConnection(dummyConn) + handler.ComInitDB(dummyConn, "test") + + prepareData := &mysql.PrepareData{ + StatementID: 0, + PrepareStmt: "select 42 + ? from dual", + ParamsCount: 0, + ParamsType: nil, + ColumnNames: nil, + BindVars: map[string]*query.BindVariable{ + "v1": {Type: query.Type_INT8, Value: []byte("5")}, + }, + } + + // Create a prepared statement, a table lock, and a user var in the current session + _, err := handler.ComPrepare(dummyConn, prepareData.PrepareStmt, prepareData) + require.NoError(t, err) + _, cached := e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt) + require.True(t, cached) + err = handler.ComQuery(dummyConn, "SET @userVar = 42;", func(res *sqltypes.Result, more bool) error { + return nil + }) + require.NoError(t, err) + + // Reset the connection to clear all session state + err = handler.ComResetConnection(dummyConn) + require.NoError(t, err) + + // Assert that the session is clean – the selected database should not change, and all session state + // such as user vars, session vars, prepared statements, table locks, and temporary tables should be cleared. + err = handler.ComQuery(dummyConn, "SELECT database()", func(res *sqltypes.Result, more bool) error { + require.Equal(t, "test", res.Rows[0][0].ToString()) + return nil + }) + require.NoError(t, err) + _, cached = e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt) + require.False(t, cached) + err = handler.ComQuery(dummyConn, "SELECT @userVar;", func(res *sqltypes.Result, more bool) error { + require.True(t, res.Rows[0][0].IsNull()) + return nil + }) + require.NoError(t, err) +} + func TestHandlerComPrepare(t *testing.T) { e, pro := setupMemDB(require.New(t)) dummyConn := newConn(1) @@ -1026,9 +1089,6 @@ func testServer(t *testing.T, ready chan struct{}, port string, breakConn bool) func okTestServer(t *testing.T, ready chan struct{}, port string) { testServer(t, ready, port, false) } -func brokenTestServer(t *testing.T, ready chan struct{}, port string) { - testServer(t, ready, port, true) -} // This session builder is used as dummy mysql Conn is not complete and // causes panic when accessing remote address.