Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require (
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e
github.com/dolthub/jsonpath v0.0.2-0.20240201003050-392940944c15
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813
github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c
github.com/go-kit/kit v0.10.0
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
github.com/gocraft/dbr/v2 v2.7.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9X
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813 h1:tGwsoLAMFQ+7FDEyIWOIJ1Vc/nptbFi0Fh7SQahB8ro=
github.com/dolthub/vitess v0.0.0-20240205203605-9e6c6d650813/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c h1:Zt23BHsxvPHGfpHV9k/FcsHqWZjfybyQQux2OLpRni8=
github.com/dolthub/vitess v0.0.0-20240206204925-6acf16fa777c/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
Expand Down
4 changes: 2 additions & 2 deletions server/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ func (ih *interceptorHandler) WarningCount(c *mysql.Conn) uint16 {
return ih.h.WarningCount(c)
}

func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) {
ih.h.ComResetConnection(c)
func (ih *interceptorHandler) ComResetConnection(c *mysql.Conn) error {
return ih.h.ComResetConnection(c)
}

func (ih *interceptorHandler) ParserOptionsForConnection(c *mysql.Conn) (ast.ParserOptions, error) {
Expand Down
4 changes: 2 additions & 2 deletions server/golden/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func (h MySqlProxy) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, ca
}

// ComResetConnection implements mysql.Handler.
func (h MySqlProxy) ComResetConnection(c *mysql.Conn) {
return
func (h MySqlProxy) ComResetConnection(_ *mysql.Conn) error {
return nil
}

// ConnectionClosed implements mysql.Handler.
Expand Down
4 changes: 2 additions & 2 deletions server/golden/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ func (v Validator) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, cal
return fmt.Errorf("ComStmtExecute unsupported")
}

func (v Validator) ComResetConnection(c *mysql.Conn) {
return
func (v Validator) ComResetConnection(_ *mysql.Conn) error {
return nil
}

// ConnectionClosed reports that a connection has been closed.
Expand Down
37 changes: 33 additions & 4 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server

import (
"context"
"encoding/base64"
"fmt"
"io"
Expand Down Expand Up @@ -197,10 +198,32 @@ func (h *Handler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, call
return err
}

func (h *Handler) ComResetConnection(c *mysql.Conn) {
// ComResetConnection implements the mysql.Handler interface.
//
// This command resets the connection's session, clearing out any cached prepared statements, locks, user and
// session variables. The currently selected database is preserved.
//
// The COM_RESET command can be sent manually through the mysql client by issuing the "resetconnection" (or "\x")
// client command.
func (h *Handler) ComResetConnection(c *mysql.Conn) error {
logrus.WithField("connectionId", c.ConnectionID).Debug("COM_RESET_CONNECTION command received")

// TODO: handle reset logic
// Grab the currently selected database name
s := h.sm.session(c)
db := s.GetCurrentDatabase()

// Dispose of the connection's current session
h.maybeReleaseAllLocks(c)
h.e.CloseSession(c.ConnectionID)

// Create a new session and set the current database
err := h.sm.NewSession(context.Background(), c)
if err != nil {
return err
}
s = h.sm.session(c)
s.SetCurrentDatabase(db)
return nil
}

func (h *Handler) ParserOptionsForConnection(c *mysql.Conn) (sqlparser.ParserOptions, error) {
Expand All @@ -222,6 +245,14 @@ func (h *Handler) ConnectionClosed(c *mysql.Conn) {
defer h.sm.RemoveConn(c)
defer h.e.CloseSession(c.ConnectionID)

h.maybeReleaseAllLocks(c)

logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed")
}

// maybeReleaseAllLocks makes a best effort attempt to release all locks on the given connection. If the attempt fails,
// an error is logged but not returned.
func (h *Handler) maybeReleaseAllLocks(c *mysql.Conn) {
if ctx, err := h.sm.NewContextWithQuery(c, ""); err != nil {
logrus.Errorf("unable to release all locks on session close: %s", err)
logrus.Errorf("unable to unlock tables on session close: %s", err)
Expand All @@ -234,8 +265,6 @@ func (h *Handler) ConnectionClosed(c *mysql.Conn) {
logrus.Errorf("unable to unlock tables on session close: %s", err)
}
}

logrus.WithField(sql.ConnectionIdLogField, c.ConnectionID).Infof("ConnectionClosed")
}

func (h *Handler) ComMultiQuery(
Expand Down
66 changes: 63 additions & 3 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,69 @@ func TestHandlerOutput(t *testing.T) {
})
}

// TestHandlerComReset asserts that the Handler.ComResetConnection method correctly clears all session
// state (e.g. table locks, prepared statements, user variables, session variables), and keeps the current
// database selected.
func TestHandlerComResetConnection(t *testing.T) {
e, pro := setupMemDB(require.New(t))
dummyConn := newConn(1)
dbFunc := pro.Database

handler := &Handler{
e: e,
sm: NewSessionManager(
testSessionBuilder(pro),
sql.NoopTracer,
dbFunc,
sql.NewMemoryManager(nil),
sqle.NewProcessList(),
"foo",
),
}
handler.NewConnection(dummyConn)
handler.ComInitDB(dummyConn, "test")

prepareData := &mysql.PrepareData{
StatementID: 0,
PrepareStmt: "select 42 + ? from dual",
ParamsCount: 0,
ParamsType: nil,
ColumnNames: nil,
BindVars: map[string]*query.BindVariable{
"v1": {Type: query.Type_INT8, Value: []byte("5")},
},
}

// Create a prepared statement, a table lock, and a user var in the current session
_, err := handler.ComPrepare(dummyConn, prepareData.PrepareStmt, prepareData)
require.NoError(t, err)
_, cached := e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt)
require.True(t, cached)
err = handler.ComQuery(dummyConn, "SET @userVar = 42;", func(res *sqltypes.Result, more bool) error {
return nil
})
require.NoError(t, err)

// Reset the connection to clear all session state
err = handler.ComResetConnection(dummyConn)
require.NoError(t, err)

// Assert that the session is clean – the selected database should not change, and all session state
// such as user vars, session vars, prepared statements, table locks, and temporary tables should be cleared.
err = handler.ComQuery(dummyConn, "SELECT database()", func(res *sqltypes.Result, more bool) error {
require.Equal(t, "test", res.Rows[0][0].ToString())
return nil
})
require.NoError(t, err)
_, cached = e.PreparedDataCache.GetCachedStmt(dummyConn.ConnectionID, prepareData.PrepareStmt)
require.False(t, cached)
err = handler.ComQuery(dummyConn, "SELECT @userVar;", func(res *sqltypes.Result, more bool) error {
require.True(t, res.Rows[0][0].IsNull())
return nil
})
require.NoError(t, err)
}

func TestHandlerComPrepare(t *testing.T) {
e, pro := setupMemDB(require.New(t))
dummyConn := newConn(1)
Expand Down Expand Up @@ -1026,9 +1089,6 @@ func testServer(t *testing.T, ready chan struct{}, port string, breakConn bool)
func okTestServer(t *testing.T, ready chan struct{}, port string) {
testServer(t, ready, port, false)
}
func brokenTestServer(t *testing.T, ready chan struct{}, port string) {
testServer(t, ready, port, true)
}

// This session builder is used as dummy mysql Conn is not complete and
// causes panic when accessing remote address.
Expand Down