diff --git a/go/mysql/constants.go b/go/mysql/constants.go index dc961154734..4d1a530a861 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -114,9 +114,9 @@ const ( // Client supports plugin authentication. CapabilityClientPluginAuth = 1 << 19 - // CLIENT_CONNECT_ATTRS 1 << 20 + // CapabilityClientConnAttr is CLIENT_CONNECT_ATTRS // Permits connection attributes in Protocol::HandshakeResponse41. - // Not yet supported. + CapabilityClientConnAttr = 1 << 20 // CapabilityClientPluginAuthLenencClientData is CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA CapabilityClientPluginAuthLenencClientData = 1 << 21 diff --git a/go/mysql/server.go b/go/mysql/server.go index c21535be12f..f32977ae46a 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -470,7 +470,8 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en CapabilityClientMultiResults | CapabilityClientPluginAuth | CapabilityClientPluginAuthLenencClientData | - CapabilityClientDeprecateEOF + CapabilityClientDeprecateEOF | + CapabilityClientConnAttr if enableTLS { capabilities |= CapabilityClientSSL } @@ -669,11 +670,66 @@ func (l *Listener) parseClientHandshakePacket(c *Conn, firstTime bool, data []by authMethod = MysqlNativePassword } - // FIXME(alainjobart) Add CLIENT_CONNECT_ATTRS parsing if we need it. + // Decode connection attributes send by the client + if clientFlags&CapabilityClientConnAttr != 0 { + var attrs map[string]string + var err error + attrs, pos, err = parseConnAttrs(data, pos) + if err != nil { + return "", "", nil, err + } + log.Infof("Connection Attributes: %-v", attrs) + } return username, authMethod, authResponse, nil } +func parseConnAttrs(data []byte, pos int) (map[string]string, int, error) { + var attrLen uint64 + + attrLen, pos, ok := readLenEncInt(data, pos) + if !ok { + return nil, 0, fmt.Errorf("parseClientHandshakePacket: can't read connection attributes variable length") + } + + var attrLenRead uint64 + + attrs := make(map[string]string) + + for attrLenRead < attrLen { + var keyLen byte + keyLen, pos, ok = readByte(data, pos) + if !ok { + return nil, 0, fmt.Errorf("parseClientHandshakePacket: can't read connection attribute key length") + } + attrLenRead += uint64(keyLen) + 1 + + var connAttrKey []byte + connAttrKey, pos, ok = readBytesCopy(data, pos, int(keyLen)) + if !ok { + return nil, 0, fmt.Errorf("parseClientHandshakePacket: can't read connection attribute key") + } + + var valLen byte + valLen, pos, ok = readByte(data, pos) + if !ok { + return nil, 0, fmt.Errorf("parseClientHandshakePacket: can't read connection attribute value length") + } + attrLenRead += uint64(valLen) + 1 + + var connAttrVal []byte + connAttrVal, pos, ok = readBytesCopy(data, pos, int(valLen)) + if !ok { + return nil, 0, fmt.Errorf("parseClientHandshakePacket: can't read connection attribute value") + } + + attrs[string(connAttrKey[:])] = string(connAttrVal[:]) + } + + return attrs, pos, nil + +} + // writeAuthSwitchRequest writes an auth switch request packet. func (c *Conn) writeAuthSwitchRequest(pluginName string, pluginData []byte) error { length := 1 + // AuthSwitchRequestPacket diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 5fa1efe4f67..8895425a7cd 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -1269,3 +1269,39 @@ func TestListenerShutdown(t *testing.T) { t.Fatalf("Ping should fail after shutdown") } } + +func TestParseConnAttrs(t *testing.T) { + expected := map[string]string{ + "_client_version": "8.0.11", + "program_name": "mysql", + "_pid": "22850", + "_platform": "x86_64", + "_os": "linux-glibc2.12", + "_client_name": "libmysql", + } + + data := []byte{0x70, 0x04, 0x5f, 0x70, 0x69, 0x64, 0x05, 0x32, 0x32, 0x38, 0x35, 0x30, 0x09, 0x5f, 0x70, 0x6c, + 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x06, 0x78, 0x38, 0x36, 0x5f, 0x36, 0x34, 0x03, 0x5f, 0x6f, + 0x73, 0x0f, 0x6c, 0x69, 0x6e, 0x75, 0x78, 0x2d, 0x67, 0x6c, 0x69, 0x62, 0x63, 0x32, 0x2e, 0x31, + 0x32, 0x0c, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x08, 0x6c, + 0x69, 0x62, 0x6d, 0x79, 0x73, 0x71, 0x6c, 0x0f, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, + 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x06, 0x38, 0x2e, 0x30, 0x2e, 0x31, 0x31, 0x0c, 0x70, + 0x72, 0x6f, 0x67, 0x72, 0x61, 0x6d, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x05, 0x6d, 0x79, 0x73, 0x71, 0x6c} + + attrs, pos, err := parseConnAttrs(data, 0) + if err != nil { + t.Fatalf("Failed to read connection attributes: %v", err) + } + if pos != 113 { + t.Fatalf("Unexpeded pos after reading connection attributes: %d intead of 113", pos) + } + for k, v := range expected { + if val, ok := attrs[k]; ok { + if val != v { + t.Fatalf("Unexpected value found in attrs for key %s: got %s expected %s", k, val, v) + } + } else { + t.Fatalf("Error reading key %s from connection attributes: attrs: %-v", k, attrs) + } + } +}