diff --git a/go/mysql/client.go b/go/mysql/client.go index b45da1922d6..3ef28912914 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -166,6 +166,28 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) { return c, nil } +// Ping implements mysql ping command. +func (c *Conn) Ping() error { + // This is a new command, need to reset the sequence. + c.sequence = 0 + + if err := c.writePacket([]byte{ComPing}); err != nil { + return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err) + } + data, err := c.readEphemeralPacket() + if err != nil { + return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err) + } + defer c.recycleReadPacket() + switch data[0] { + case OKPacket: + return nil + case ErrPacket: + return ParseErrorPacket(data) + } + return fmt.Errorf("unexpected packet type: %d", data[0]) +} + // parseCharacterSet parses the provided character set. // Returns SQLError(CRCantReadCharset) if it can't. func parseCharacterSet(cs string) (uint8, error) { diff --git a/go/mysql/constants.go b/go/mysql/constants.go index afe203847be..dc961154734 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -440,6 +440,9 @@ const ( // SSHandshakeError is ER_HANDSHAKE_ERROR SSHandshakeError = "08S01" + // SSServerShutdown is ER_SERVER_SHUTDOWN + SSServerShutdown = "08S01" + // SSDataTooLong is ER_DATA_TOO_LONG SSDataTooLong = "22001" diff --git a/go/mysql/server.go b/go/mysql/server.go index 4a022707b7e..6c00b1656f5 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -27,6 +27,7 @@ import ( "vitess.io/vitess/go/netutil" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" + "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/tb" "vitess.io/vitess/go/vt/log" ) @@ -132,6 +133,9 @@ type Listener struct { // connReadBufferSize is size of buffer for reads from underlying connection. // Reads are unbuffered if it's <=0. connReadBufferSize int + + // shutdown indicates that Shutdown method was called. + shutdown sync2.AtomicBool } // NewFromListener creares a new mysql listener from an existing net.Listener @@ -472,11 +476,18 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } case ComPing: - // No payload to that one, just return OKPacket. c.recycleReadPacket() - if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Errorf("Error writing ComPing result to %s: %v", c, err) - return + // Return error if listener was shut down and OK otherwise + if l.isShutdown() { + if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { + log.Errorf("Error writing ComPing error to %s: %v", c, err) + return + } + } else { + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Errorf("Error writing ComPing result to %s: %v", c, err) + return + } } case ComSetOption: if operation, ok := c.parseComSetOption(data); ok { @@ -514,11 +525,23 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } } -// Close stops the listener, and closes all connections. +// Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed. func (l *Listener) Close() { l.listener.Close() } +// Shutdown closes listener and fails any Ping requests from existing connections. +// This can be used for graceful shutdown, to let clients know that they should reconnect to another server. +func (l *Listener) Shutdown() { + if l.shutdown.CompareAndSwap(false, true) { + l.Close() + } +} + +func (l *Listener) isShutdown() bool { + return l.shutdown.Get() +} + // writeHandshakeV10 writes the Initial Handshake Packet, server side. // It returns the salt data. func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, enableTLS bool) ([]byte, error) { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 57004ce6926..ddf4c01b943 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -1087,3 +1087,60 @@ func binaryPath(root, binary string) (string, error) { return "", fmt.Errorf("%s not found in any of %s/{%s}", binary, root, strings.Join(subdirs, ",")) } + +func TestListenerShutdown(t *testing.T) { + th := &testHandler{} + authServer := NewAuthServerStatic() + authServer.Entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + UserData: "userData1", + }} + l, err := NewListener("tcp", ":0", authServer, th, 0, 0) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + defer l.Close() + go l.Accept() + + host, port := getHostPort(t, l.Addr()) + + // Setup the right parameters. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := Connect(ctx, params) + if err != nil { + t.Fatalf("Can't connect to listener: %v", err) + } + + if err := conn.Ping(); err != nil { + t.Fatalf("Ping failed: %v", err) + } + + l.Shutdown() + + if err := conn.Ping(); err != nil { + sqlErr, ok := err.(*SQLError) + if !ok { + t.Fatalf("Wrong error type: %T", err) + } + if sqlErr.Number() != ERServerShutdown { + t.Fatalf("Unexpected sql error code: %d", sqlErr.Number()) + } + if sqlErr.SQLState() != SSServerShutdown { + t.Fatalf("Unexpected error sql state: %s", sqlErr.SQLState()) + } + if sqlErr.Message != "Server shutdown in progress" { + t.Fatalf("Unexpected error message: %s", sqlErr.Message) + } + } else { + t.Fatalf("Ping should fail after shutdown") + } +}