Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ func (c *Conn) handleComStmtReset(data []byte) bool {
}

func (c *Conn) handleComStmtSendLongData(data []byte) bool {
stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data)
stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data)
c.recycleReadPacket()
if !ok {
err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
Expand All @@ -903,9 +903,6 @@ func (c *Conn) handleComStmtSendLongData(data []byte) bool {
return c.writeErrorPacketFromErrorAndLog(err)
}

chunk := make([]byte, len(chunkData))
copy(chunk, chunkData)

key := fmt.Sprintf("v%d", paramID+1)
if val, ok := prepare.BindVars[key]; ok {
val.Value = append(val.Value, chunk...)
Expand Down
141 changes: 141 additions & 0 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package mysql

import (
"bytes"
"context"
crypto_rand "crypto/rand"
"encoding/binary"
"fmt"
"math/rand"
"net"
Expand All @@ -28,6 +30,7 @@ import (
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -73,6 +76,7 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) {
// Create a Conn on both sides.
cConn := newConn(clientConn)
sConn := newConn(serverConn)
sConn.PrepareData = map[uint32]*PrepareData{}

return listener, sConn, cConn
}
Expand Down Expand Up @@ -500,3 +504,140 @@ func (m mockAddress) String() string {
}

var _ net.Addr = (*mockAddress)(nil)

var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")

func randSeq(n int) string {
b := make([]rune, n)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return string(b)
}

func TestPrepareAndExecute(t *testing.T) {
// this test starts a lot of clients that all send prepared statement parameter values
// and check that the handler received the correct input
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
for i := 0; i < 1000; i++ {
startGoRoutine(ctx, t, randSeq(i))
}

for {
select {
case <-ctx.Done():
return
default:
if t.Failed() {
return
}
}
}
}

func startGoRoutine(ctx context.Context, t *testing.T, s string) {
go func(longData string) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

sql := "SELECT * FROM test WHERE id = ?"
mockData := preparePacket(t, sql)

err := cConn.writePacket(mockData)
require.NoError(t, err)

handler := &testRun{
t: t,
expParamCounts: 1,
expQuery: sql,
expStmtID: 1,
}

ok := sConn.handleNextCommand(handler)
require.True(t, ok, "oh noes")

resp, err := cConn.ReadPacket()
require.NoError(t, err)
require.EqualValues(t, 0, resp[0])

for count := 0; ; count++ {
select {
case <-ctx.Done():
return
default:
}
cConn.sequence = 0
longDataPacket := createSendLongDataPacket(sConn.StatementID, 0, []byte(longData))
err = cConn.writePacket(longDataPacket)
assert.NoError(t, err)

assert.True(t, sConn.handleNextCommand(handler))
data := sConn.PrepareData[sConn.StatementID]
assert.NotNil(t, data)
variable := data.BindVars["v1"]
assert.NotNil(t, variable, fmt.Sprintf("%#v", data.BindVars))
assert.Equalf(t, []byte(longData), variable.Value[len(longData)*count:], "failed at: %d", count)
}
}(s)
}

func createSendLongDataPacket(stmtID uint32, paramID uint16, data []byte) []byte {
stmtIDBinary := make([]byte, 4)
binary.LittleEndian.PutUint32(stmtIDBinary, stmtID)

paramIDBinary := make([]byte, 2)
binary.LittleEndian.PutUint16(paramIDBinary, paramID)

packet := []byte{0, 0, 0, 0, ComStmtSendLongData}
packet = append(packet, stmtIDBinary...) // append stmt ID
packet = append(packet, paramIDBinary...) // append param ID
packet = append(packet, data...) // append data
return packet
}

type testRun struct {
t *testing.T
expParamCounts int
expQuery string
expStmtID int
}

func (t testRun) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error {
panic("implement me")
}

func (t testRun) NewConnection(c *Conn) {
panic("implement me")
}

func (t testRun) ConnectionClosed(c *Conn) {
panic("implement me")
}

func (t testRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error {
panic("implement me")
}

func (t testRun) ComPrepare(c *Conn, query string, bv map[string]*querypb.BindVariable) ([]*querypb.Field, error) {
assert.Equal(t.t, t.expQuery, query)
assert.EqualValues(t.t, t.expStmtID, c.StatementID)
assert.NotNil(t.t, c.PrepareData[c.StatementID])
assert.EqualValues(t.t, t.expParamCounts, c.PrepareData[c.StatementID].ParamsCount)
assert.Len(t.t, c.PrepareData, int(c.PrepareData[c.StatementID].ParamsCount))
return nil, nil
}

func (t testRun) WarningCount(c *Conn) uint16 {
return 0
}

func (t testRun) ComResetConnection(c *Conn) {
panic("implement me")
}

var _ Handler = (*testRun)(nil)
6 changes: 5 additions & 1 deletion go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,11 @@ func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bo
return 0, 0, nil, false
}

return statementID, paramID, data[pos:], true
chunkData := data[pos:]
chunk := make([]byte, len(chunkData))
copy(chunk, chunkData)

return statementID, paramID, chunk, true
}

func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) {
Expand Down
4 changes: 2 additions & 2 deletions go/mysql/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
)

// Utility function to write sql query as packets to test parseComPrepare
func MockQueryPackets(t *testing.T, query string) []byte {
func preparePacket(t *testing.T, query string) []byte {
data := make([]byte, len(query)+1+packetHeaderSize)
// Not sure if it makes a difference
pos := packetHeaderSize
Expand Down Expand Up @@ -127,7 +127,7 @@ func TestComStmtPrepare(t *testing.T) {
}()

sql := "select * from test_table where id = ?"
mockData := MockQueryPackets(t, sql)
mockData := preparePacket(t, sql)

if err := cConn.writePacket(mockData); err != nil {
t.Fatalf("writePacket failed: %v", err)
Expand Down