diff --git a/server/handshake_resp.go b/server/handshake_resp.go index 8fc4ff7b7..762843865 100644 --- a/server/handshake_resp.go +++ b/server/handshake_resp.go @@ -60,8 +60,13 @@ func (c *Conn) readFirstPart() ([]byte, int, error) { return c.decodeFirstPart(data) } -func (c *Conn) decodeFirstPart(data []byte) ([]byte, int, error) { - pos := 0 +func (c *Conn) decodeFirstPart(data []byte) (newData []byte, pos int, err error) { + // prevent 'panic: runtime error: index out of range' error + defer func() { + if recover() != nil { + err = NewDefaultError(ER_HANDSHAKE_ERROR) + } + }() // check CLIENT_PROTOCOL_41 if uint32(binary.LittleEndian.Uint16(data[:2]))&CLIENT_PROTOCOL_41 == 0 { diff --git a/server/handshake_resp_test.go b/server/handshake_resp_test.go index c75b9d180..5f6cd616f 100644 --- a/server/handshake_resp_test.go +++ b/server/handshake_resp_test.go @@ -28,10 +28,17 @@ func TestReadAuthData(t *testing.T) { } func TestDecodeFirstPart(t *testing.T) { - data := []byte{141, 174, 255, 1, 0, 0, 0, 1, 8} - c := &Conn{} + // test out of range index returns 'bad handshake' error + _, _, err := c.decodeFirstPart([]byte{141, 174}) + if err == nil || err.Error() != "ERROR 1043 (08S01): Bad handshake" { + t.Fatal("expected error, got nil") + } + + // test good index position + data := []byte{141, 174, 255, 1, 0, 0, 0, 1, 8} + result, pos, err := c.decodeFirstPart(data) if err != nil { t.Fatalf("expected nil error, got %v", err)