diff --git a/go/mysql/client.go b/go/mysql/client.go index 34dcacbc0f2..ee9ff368fda 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -373,6 +373,16 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error) if !ok { return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version") } + + // Server is allowed to immediately send ERR packet + if pver == ErrPacket { + errorCode, pos, _ := readUint16(data, pos) + // Normally there would be a 1-byte sql_state_marker field and a 5-byte + // sql_state field here, but docs say these will not be present in this case. + errorMsg, pos, _ := readEOFString(data, pos) + return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg) + } + if pver != protocolVersion { return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver) } diff --git a/go/mysql/encoding.go b/go/mysql/encoding.go index 018da8a8f9b..7926c5bd498 100644 --- a/go/mysql/encoding.go +++ b/go/mysql/encoding.go @@ -80,6 +80,10 @@ func lenNullString(value string) int { return len(value) + 1 } +func lenEOFString(value string) int { + return len(value) +} + func writeNullString(data []byte, pos int, value string) int { pos += copy(data[pos:], value) data[pos] = 0 @@ -180,6 +184,10 @@ func readNullString(data []byte, pos int) (string, int, bool) { return string(data[pos : pos+end]), pos + end + 1, true } +func readEOFString(data []byte, pos int) (string, int, bool) { + return string(data[pos:]), len(data) - pos, true +} + func readUint16(data []byte, pos int) (uint16, int, bool) { if pos+1 >= len(data) { return 0, 0, false diff --git a/go/mysql/encoding_test.go b/go/mysql/encoding_test.go index 7f56f583c86..8d1bb804936 100644 --- a/go/mysql/encoding_test.go +++ b/go/mysql/encoding_test.go @@ -190,21 +190,25 @@ func TestEncString(t *testing.T) { value string lenEncoded []byte nullEncoded []byte + eofEncoded []byte }{ { "", []byte{0x00}, []byte{0x00}, + []byte{}, }, { "a", []byte{0x01, 'a'}, []byte{'a', 0x00}, + []byte{'a'}, }, { "0123456789", []byte{0x0a, '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}, []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 0x00}, + []byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'}, }, } for _, test := range tests { @@ -220,6 +224,11 @@ func TestEncString(t *testing.T) { t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.nullEncoded), test.value) } + // Check lenEOFString + if got := lenEOFString(test.value); got != len(test.eofEncoded) { + t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.eofEncoded), test.value) + } + // Check successful encoding. data := make([]byte, len(test.lenEncoded)) pos := writeLenEncString(data, 0, test.value) @@ -319,16 +328,21 @@ func TestEncString(t *testing.T) { } // EOF encoded tests. - // We use the nullEncoded value, removing the 0 at the end. // Check successful encoding. - data = make([]byte, len(test.nullEncoded)-1) + data = make([]byte, len(test.eofEncoded)) pos = writeEOFString(data, 0, test.value) - if pos != len(test.nullEncoded)-1 { - t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.nullEncoded)-1) + if pos != len(test.eofEncoded) { + t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.eofEncoded)) } - if !bytes.Equal(data, test.nullEncoded[:len(test.nullEncoded)-1]) { - t.Errorf("unexpected nullEncoded value for %v, got %v expected %v", test.value, data, test.nullEncoded) + if !bytes.Equal(data, test.eofEncoded[:len(test.eofEncoded)]) { + t.Errorf("unexpected eofEncoded value for %v, got %v expected %v", test.value, data, test.eofEncoded) + } + + // Check successful decoding. + got, pos, ok = readEOFString(test.eofEncoded, 0) + if !ok || got != test.value || pos != len(test.eofEncoded) { + t.Errorf("readEOFString returned %v/%v/%v but expected %v/%v/%v", got, pos, ok, test.value, len(test.eofEncoded), true) } } }