diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 5b77ac2e742..6f151a4b9f3 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -883,7 +883,7 @@ func (c *Conn) handleComStmtReset(data []byte) bool { } func (c *Conn) handleComStmtSendLongData(data []byte) bool { - stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) + stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() if !ok { err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) @@ -903,9 +903,6 @@ func (c *Conn) handleComStmtSendLongData(data []byte) bool { return c.writeErrorPacketFromErrorAndLog(err) } - chunk := make([]byte, len(chunkData)) - copy(chunk, chunkData) - key := fmt.Sprintf("v%d", paramID+1) if val, ok := prepare.BindVars[key]; ok { val.Value = append(val.Value, chunk...) diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 9646ca229ca..0dbb040f14c 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -18,7 +18,9 @@ package mysql import ( "bytes" + "context" crypto_rand "crypto/rand" + "encoding/binary" "fmt" "math/rand" "net" @@ -28,6 +30,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" @@ -73,6 +76,7 @@ func createSocketPair(t *testing.T) (net.Listener, *Conn, *Conn) { // Create a Conn on both sides. cConn := newConn(clientConn) sConn := newConn(serverConn) + sConn.PrepareData = map[uint32]*PrepareData{} return listener, sConn, cConn } @@ -500,3 +504,140 @@ func (m mockAddress) String() string { } var _ net.Addr = (*mockAddress)(nil) + +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func TestPrepareAndExecute(t *testing.T) { + // this test starts a lot of clients that all send prepared statement parameter values + // and check that the handler received the correct input + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + for i := 0; i < 1000; i++ { + startGoRoutine(ctx, t, randSeq(i)) + } + + for { + select { + case <-ctx.Done(): + return + default: + if t.Failed() { + return + } + } + } +} + +func startGoRoutine(ctx context.Context, t *testing.T, s string) { + go func(longData string) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + sql := "SELECT * FROM test WHERE id = ?" + mockData := preparePacket(t, sql) + + err := cConn.writePacket(mockData) + require.NoError(t, err) + + handler := &testRun{ + t: t, + expParamCounts: 1, + expQuery: sql, + expStmtID: 1, + } + + ok := sConn.handleNextCommand(handler) + require.True(t, ok, "oh noes") + + resp, err := cConn.ReadPacket() + require.NoError(t, err) + require.EqualValues(t, 0, resp[0]) + + for count := 0; ; count++ { + select { + case <-ctx.Done(): + return + default: + } + cConn.sequence = 0 + longDataPacket := createSendLongDataPacket(sConn.StatementID, 0, []byte(longData)) + err = cConn.writePacket(longDataPacket) + assert.NoError(t, err) + + assert.True(t, sConn.handleNextCommand(handler)) + data := sConn.PrepareData[sConn.StatementID] + assert.NotNil(t, data) + variable := data.BindVars["v1"] + assert.NotNil(t, variable, fmt.Sprintf("%#v", data.BindVars)) + assert.Equalf(t, []byte(longData), variable.Value[len(longData)*count:], "failed at: %d", count) + } + }(s) +} + +func createSendLongDataPacket(stmtID uint32, paramID uint16, data []byte) []byte { + stmtIDBinary := make([]byte, 4) + binary.LittleEndian.PutUint32(stmtIDBinary, stmtID) + + paramIDBinary := make([]byte, 2) + binary.LittleEndian.PutUint16(paramIDBinary, paramID) + + packet := []byte{0, 0, 0, 0, ComStmtSendLongData} + packet = append(packet, stmtIDBinary...) // append stmt ID + packet = append(packet, paramIDBinary...) // append param ID + packet = append(packet, data...) // append data + return packet +} + +type testRun struct { + t *testing.T + expParamCounts int + expQuery string + expStmtID int +} + +func (t testRun) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (t testRun) NewConnection(c *Conn) { + panic("implement me") +} + +func (t testRun) ConnectionClosed(c *Conn) { + panic("implement me") +} + +func (t testRun) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { + panic("implement me") +} + +func (t testRun) ComPrepare(c *Conn, query string, bv map[string]*querypb.BindVariable) ([]*querypb.Field, error) { + assert.Equal(t.t, t.expQuery, query) + assert.EqualValues(t.t, t.expStmtID, c.StatementID) + assert.NotNil(t.t, c.PrepareData[c.StatementID]) + assert.EqualValues(t.t, t.expParamCounts, c.PrepareData[c.StatementID].ParamsCount) + assert.Len(t.t, c.PrepareData, int(c.PrepareData[c.StatementID].ParamsCount)) + return nil, nil +} + +func (t testRun) WarningCount(c *Conn) uint16 { + return 0 +} + +func (t testRun) ComResetConnection(c *Conn) { + panic("implement me") +} + +var _ Handler = (*testRun)(nil) diff --git a/go/mysql/query.go b/go/mysql/query.go index 94449457b10..8e301aec98b 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -845,7 +845,11 @@ func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bo return 0, 0, nil, false } - return statementID, paramID, data[pos:], true + chunkData := data[pos:] + chunk := make([]byte, len(chunkData)) + copy(chunk, chunkData) + + return statementID, paramID, chunk, true } func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) { diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 105271dc6d2..5eb617d3f2d 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -30,7 +30,7 @@ import ( ) // Utility function to write sql query as packets to test parseComPrepare -func MockQueryPackets(t *testing.T, query string) []byte { +func preparePacket(t *testing.T, query string) []byte { data := make([]byte, len(query)+1+packetHeaderSize) // Not sure if it makes a difference pos := packetHeaderSize @@ -127,7 +127,7 @@ func TestComStmtPrepare(t *testing.T) { }() sql := "select * from test_table where id = ?" - mockData := MockQueryPackets(t, sql) + mockData := preparePacket(t, sql) if err := cConn.writePacket(mockData); err != nil { t.Fatalf("writePacket failed: %v", err)