Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,28 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
return c, nil
}

// Ping implements mysql ping command.
func (c *Conn) Ping() error {
// This is a new command, need to reset the sequence.
c.sequence = 0

if err := c.writePacket([]byte{ComPing}); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err)
}
data, err := c.readEphemeralPacket()
if err != nil {
return NewSQLError(CRServerLost, SSUnknownSQLState, "%v", err)
}
defer c.recycleReadPacket()
switch data[0] {
case OKPacket:
return nil
case ErrPacket:
return ParseErrorPacket(data)
}
return fmt.Errorf("unexpected packet type: %d", data[0])
}

// parseCharacterSet parses the provided character set.
// Returns SQLError(CRCantReadCharset) if it can't.
func parseCharacterSet(cs string) (uint8, error) {
Expand Down
3 changes: 3 additions & 0 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,9 @@ const (
// SSHandshakeError is ER_HANDSHAKE_ERROR
SSHandshakeError = "08S01"

// SSServerShutdown is ER_SERVER_SHUTDOWN
SSServerShutdown = "08S01"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is the same # as above. is that right?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, and there is one more above it. I'm not sure what does it mean, but that's state from the reference manual for that type of error.


// SSDataTooLong is ER_DATA_TOO_LONG
SSDataTooLong = "22001"

Expand Down
33 changes: 28 additions & 5 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"vitess.io/vitess/go/netutil"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/stats"
"vitess.io/vitess/go/sync2"
"vitess.io/vitess/go/tb"
"vitess.io/vitess/go/vt/log"
)
Expand Down Expand Up @@ -132,6 +133,9 @@ type Listener struct {
// connReadBufferSize is size of buffer for reads from underlying connection.
// Reads are unbuffered if it's <=0.
connReadBufferSize int

// shutdown indicates that Shutdown method was called.
shutdown sync2.AtomicBool
}

// NewFromListener creares a new mysql listener from an existing net.Listener
Expand Down Expand Up @@ -472,11 +476,18 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}

case ComPing:
// No payload to that one, just return OKPacket.
c.recycleReadPacket()
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComPing result to %s: %v", c, err)
return
// Return error if listener was shut down and OK otherwise
if l.isShutdown() {
if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil {
log.Errorf("Error writing ComPing error to %s: %v", c, err)
return
}
} else {
if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Errorf("Error writing ComPing result to %s: %v", c, err)
return
}
}
case ComSetOption:
if operation, ok := c.parseComSetOption(data); ok {
Expand Down Expand Up @@ -514,11 +525,23 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
}

// Close stops the listener, and closes all connections.
// Close stops the listener, which prevents accept of any new connections. Existing connections won't be closed.
func (l *Listener) Close() {
l.listener.Close()
}

// Shutdown closes listener and fails any Ping requests from existing connections.
// This can be used for graceful shutdown, to let clients know that they should reconnect to another server.
func (l *Listener) Shutdown() {
if l.shutdown.CompareAndSwap(false, true) {
l.Close()
}
}

func (l *Listener) isShutdown() bool {
return l.shutdown.Get()
}

// writeHandshakeV10 writes the Initial Handshake Packet, server side.
// It returns the salt data.
func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, enableTLS bool) ([]byte, error) {
Expand Down
57 changes: 57 additions & 0 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1087,3 +1087,60 @@ func binaryPath(root, binary string) (string, error) {
return "", fmt.Errorf("%s not found in any of %s/{%s}",
binary, root, strings.Join(subdirs, ","))
}

func TestListenerShutdown(t *testing.T) {
th := &testHandler{}
authServer := NewAuthServerStatic()
authServer.Entries["user1"] = []*AuthServerStaticEntry{{
Password: "password1",
UserData: "userData1",
}}
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
defer l.Close()
go l.Accept()

host, port := getHostPort(t, l.Addr())

// Setup the right parameters.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "password1",
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

conn, err := Connect(ctx, params)
if err != nil {
t.Fatalf("Can't connect to listener: %v", err)
}

if err := conn.Ping(); err != nil {
t.Fatalf("Ping failed: %v", err)
}

l.Shutdown()

if err := conn.Ping(); err != nil {
sqlErr, ok := err.(*SQLError)
if !ok {
t.Fatalf("Wrong error type: %T", err)
}
if sqlErr.Number() != ERServerShutdown {
t.Fatalf("Unexpected sql error code: %d", sqlErr.Number())
}
if sqlErr.SQLState() != SSServerShutdown {
t.Fatalf("Unexpected error sql state: %s", sqlErr.SQLState())
}
if sqlErr.Message != "Server shutdown in progress" {
t.Fatalf("Unexpected error message: %s", sqlErr.Message)
}
} else {
t.Fatalf("Ping should fail after shutdown")
}
}