Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,16 @@ func (c *Conn) parseInitialHandshakePacket(data []byte) (uint32, []byte, error)
if !ok {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "parseInitialHandshakePacket: packet has no protocol version")
}

// Server is allowed to immediately send ERR packet
if pver == ErrPacket {
errorCode, pos, _ := readUint16(data, pos)
// Normally there would be a 1-byte sql_state_marker field and a 5-byte
// sql_state field here, but docs say these will not be present in this case.
errorMsg, pos, _ := readEOFString(data, pos)
return 0, nil, NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "immediate error from server errorCode=%v errorMsg=%v", errorCode, errorMsg)
}

if pver != protocolVersion {
return 0, nil, NewSQLError(CRVersionError, SSUnknownSQLState, "bad protocol version: %v", pver)
}
Expand Down
8 changes: 8 additions & 0 deletions go/mysql/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func lenNullString(value string) int {
return len(value) + 1
}

func lenEOFString(value string) int {
return len(value)
}

func writeNullString(data []byte, pos int, value string) int {
pos += copy(data[pos:], value)
data[pos] = 0
Expand Down Expand Up @@ -180,6 +184,10 @@ func readNullString(data []byte, pos int) (string, int, bool) {
return string(data[pos : pos+end]), pos + end + 1, true
}

func readEOFString(data []byte, pos int) (string, int, bool) {
return string(data[pos:]), len(data) - pos, true
}

func readUint16(data []byte, pos int) (uint16, int, bool) {
if pos+1 >= len(data) {
return 0, 0, false
Expand Down
26 changes: 20 additions & 6 deletions go/mysql/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,21 +190,25 @@ func TestEncString(t *testing.T) {
value string
lenEncoded []byte
nullEncoded []byte
eofEncoded []byte
}{
{
"",
[]byte{0x00},
[]byte{0x00},
[]byte{},
},
{
"a",
[]byte{0x01, 'a'},
[]byte{'a', 0x00},
[]byte{'a'},
},
{
"0123456789",
[]byte{0x0a, '0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 0x00},
[]byte{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9'},
},
}
for _, test := range tests {
Expand All @@ -220,6 +224,11 @@ func TestEncString(t *testing.T) {
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.nullEncoded), test.value)
}

// Check lenEOFString
if got := lenEOFString(test.value); got != len(test.eofEncoded) {
t.Errorf("lenNullString returned %v but expected %v for %v", got, len(test.eofEncoded), test.value)
}

// Check successful encoding.
data := make([]byte, len(test.lenEncoded))
pos := writeLenEncString(data, 0, test.value)
Expand Down Expand Up @@ -319,16 +328,21 @@ func TestEncString(t *testing.T) {
}

// EOF encoded tests.
// We use the nullEncoded value, removing the 0 at the end.

// Check successful encoding.
data = make([]byte, len(test.nullEncoded)-1)
data = make([]byte, len(test.eofEncoded))
pos = writeEOFString(data, 0, test.value)
if pos != len(test.nullEncoded)-1 {
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.nullEncoded)-1)
if pos != len(test.eofEncoded) {
t.Errorf("unexpected pos %v after writeEOFString(%v), expected %v", pos, test.value, len(test.eofEncoded))
}
if !bytes.Equal(data, test.nullEncoded[:len(test.nullEncoded)-1]) {
t.Errorf("unexpected nullEncoded value for %v, got %v expected %v", test.value, data, test.nullEncoded)
if !bytes.Equal(data, test.eofEncoded[:len(test.eofEncoded)]) {
t.Errorf("unexpected eofEncoded value for %v, got %v expected %v", test.value, data, test.eofEncoded)
}

// Check successful decoding.
got, pos, ok = readEOFString(test.eofEncoded, 0)
if !ok || got != test.value || pos != len(test.eofEncoded) {
t.Errorf("readEOFString returned %v/%v/%v but expected %v/%v/%v", got, pos, ok, test.value, len(test.eofEncoded), true)
}
}
}