From 30cda05138a521f0ac3bf730d9b27bb9ca6ada06 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Tue, 15 Jan 2019 18:49:03 +0800 Subject: [PATCH 01/32] support for the ComPrepare Signed-off-by: dcadevil --- go/mysql/conn.go | 81 +++++++++++- go/mysql/constants.go | 18 +++ go/mysql/fakesqldb/server.go | 5 + go/mysql/query.go | 73 +++++++++++ go/mysql/server.go | 4 + go/vt/vtgate/executor.go | 140 +++++++++++++++++++++ go/vt/vtgate/plugin_mysql_server.go | 88 +++++++++++++ go/vt/vtgate/vtgate.go | 28 +++++ go/vt/vtqueryserver/plugin_mysql_server.go | 4 + 9 files changed, 438 insertions(+), 3 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 2cf2ab99786..c75dca2645f 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -153,6 +153,21 @@ type Conn struct { // currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively. // It can be allocated from bufPool or heap and should be recycled in the same manner. currentEphemeralBuffer *[]byte + + StatementID uint32 + + PrepareData map[uint32]*prepareData +} + +// prepareData is a buffer used for store prepare statement meta data +type prepareData struct { + StatementID uint32 + PrepareStmt string + ParsedStmt *sqlparser.Statement + ParamsCount uint16 + ParamsType []int32 + ColumnNames []string + BindVars map[string]*querypb.BindVariable } // bufPool is used to allocate and free buffers in an efficient way. @@ -178,9 +193,10 @@ func newConn(conn net.Conn) *Conn { // size for reads. func newServerConn(conn net.Conn, listener *Listener) *Conn { c := &Conn{ - conn: conn, - listener: listener, - closed: sync2.NewAtomicBool(false), + conn: conn, + listener: listener, + closed: sync2.NewAtomicBool(false), + PrepareData: make(map[uint32]*prepareData), } if listener.connReadBufferSize > 0 { c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize) @@ -797,6 +813,65 @@ func (c *Conn) handleNextCommand(handler Handler) error { return err } } + case ComPrepare: + query := c.parseComPrepare(data) + c.recycleReadPacket() + + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return werr + } + } + } else { + queries = []string{query} + } + + if len(queries) != 1 { + return fmt.Errorf("can not prepare multiple statements") + } + + c.StatementID++ + prepare := &prepareData{ + StatementID: c.StatementID, + PrepareStmt: queries[0], + } + + c.PrepareData[c.StatementID] = prepare + + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + err := handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + if err := c.writePrepare(qr, c.PrepareData[c.StatementID]); err != nil { + return err + } + } + + return nil + }) + if err != nil { + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return werr + } + + delete(c.PrepareData, c.StatementID) + return nil + } default: log.Errorf("Got unhandled packet from %s, returning error: %v", c, data) c.recycleReadPacket() diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 4d1a530a861..b1d4491f637 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -153,6 +153,24 @@ const ( // ComBinlogDump is COM_BINLOG_DUMP. ComBinlogDump = 0x12 + // ComPrepare is COM_PREPARE. + ComPrepare = 0x16 + + // ComStmtExecute is COM_STMT_EXECUTE. + ComStmtExecute = 0x17 + + // ComStmtSendLongData is COM_STMT_SEND_LONG_DATA + ComStmtSendLongData = 0x18 + + // ComStmtClose is COM_STMT_CLOSE. + ComStmtClose = 0x19 + + // ComStmtReset is COM_STMT_RESET + ComStmtReset = 0x1a + + //ComStmtFetch is COM_STMT_FETCH + ComStmtFetch = 0x1c + // ComSetOption is COM_SET_OPTION ComSetOption = 0x1b diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 2272b17d7cb..1f4d465c36c 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -432,6 +432,11 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) { return entry.QueryResult, nil } +// ComPrepare is part of the mysql.Handler interface. +func (db *DB) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + // // Methods to add expected queries and results. // diff --git a/go/mysql/query.go b/go/mysql/query.go index ebaabb27a6a..75ded7e2b34 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -18,6 +18,7 @@ package mysql import ( "fmt" + "strings" "vitess.io/vitess/go/sqltypes" @@ -510,6 +511,10 @@ func (c *Conn) parseComSetOption(data []byte) (uint16, bool) { return val, ok } +func (c *Conn) parseComPrepare(data []byte) string { + return string(data[1:]) +} + func (c *Conn) parseComInitDB(data []byte) string { return string(data[1:]) } @@ -655,3 +660,71 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn return nil } + +// writePrepare writes a prepare query response to the wire. +func (c *Conn) writePrepare(result *sqltypes.Result, prepare *prepareData) error { + paramsCount := prepare.ParamsCount + columnCount := 0 + if result != nil { + columnCount = len(result.Fields) + } + if columnCount > 0 { + prepare.ColumnNames = make([]string, columnCount) + } + + data := c.startEphemeralPacket(12) + pos := 0 + + pos = writeByte(data, pos, 0x00) + pos = writeUint32(data, pos, uint32(prepare.StatementID)) + pos = writeUint16(data, pos, uint16(columnCount)) + pos = writeUint16(data, pos, uint16(paramsCount)) + pos = writeByte(data, pos, 0x00) + pos = writeUint16(data, pos, 0x0000) + + if err := c.writeEphemeralPacket(); err != nil { + return err + } + + if paramsCount > 0 { + for i := uint16(0); i < paramsCount; i++ { + if err := c.writeColumnDefinition(&querypb.Field{ + Name: "?", + Type: sqltypes.VarBinary, + Charset: 63}); err != nil { + return err + } + } + + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + } + + if result != nil { + // Now send each Field. + for i, field := range result.Fields { + field.Name = strings.Replace(field.Name, "'?'", "?", -1) + prepare.ColumnNames[i] = field.Name + if err := c.writeColumnDefinition(field); err != nil { + return err + } + } + + if columnCount > 0 { + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + } + } + + return c.flush() +} diff --git a/go/mysql/server.go b/go/mysql/server.go index f7a6099ab2d..a80023ffc9b 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -93,6 +93,10 @@ type Handler interface { // hang on to the byte slice. ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error + // ComPrepare is called when a connection receives a prepared + // statement query. + ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error + // WarningCount is called at the end of each query to obtain // the value to be returned to the client in the EOF packet. // Note that this will be called either in the context of the diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index ef319ac7d0b..ce51be128b6 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1396,3 +1396,143 @@ func buildVarCharRow(values ...string) []sqltypes.Value { } return row } + +// Prepare executes a prepare statements. +func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { + logStats := NewLogStats(ctx, method, sql, bindVars) + result, err = e.prepare(ctx, safeSession, sql, bindVars, logStats) + logStats.Error = err + + // The mysql plugin runs an implicit rollback whenever a connection closes. + // To avoid spamming the log with no-op rollback records, ignore it if + // it was a no-op record (i.e. didn't issue any queries) + if !(logStats.StmtType == "ROLLBACK" && logStats.ShardQueries == 0) { + logStats.Send() + } + return result, err +} + +func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (*sqltypes.Result, error) { + // Start an implicit transaction if necessary. + // TODO(sougou): deprecate legacyMode after all users are migrated out. + if !e.legacyAutocommit && !safeSession.Autocommit && !safeSession.InTransaction() { + if err := e.txConn.Begin(ctx, safeSession); err != nil { + return nil, err + } + } + + destKeyspace, destTabletType, dest, err := e.ParseDestinationTarget(safeSession.TargetString) + if err != nil { + return nil, err + } + + if safeSession.InTransaction() && destTabletType != topodatapb.TabletType_MASTER { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "transactions are supported only for master tablet types, current type: %v", destTabletType) + } + if bindVars == nil { + bindVars = make(map[string]*querypb.BindVariable) + } + + stmtType := sqlparser.Preview(sql) + logStats.StmtType = sqlparser.StmtType(stmtType) + + // Mysql warnings are scoped to the current session, but are + // cleared when a "non-diagnostic statement" is executed: + // https://dev.mysql.com/doc/refman/8.0/en/show-warnings.html + // + // To emulate this behavior, clear warnings from the session + // for all statements _except_ SHOW, so that SHOW WARNINGS + // can actually return them. + if stmtType != sqlparser.StmtShow { + safeSession.ClearWarnings() + } + + switch stmtType { + case sqlparser.StmtSelect: + return e.handlePrepare(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, logStats) + case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: + safeSession := safeSession + + // In legacy mode, we ignore autocommit settings. + if e.legacyAutocommit { + return &sqltypes.Result{}, nil + } + + mustCommit := false + if safeSession.Autocommit && !safeSession.InTransaction() { + mustCommit = true + if err := e.txConn.Begin(ctx, safeSession); err != nil { + return nil, err + } + // The defer acts as a failsafe. If commit was successful, + // the rollback will be a no-op. + defer e.txConn.Rollback(ctx, safeSession) + } + + // The SetAutocommitable flag should be same as mustCommit. + // If we started a transaction because of autocommit, then mustCommit + // will be true, which means that we can autocommit. If we were already + // in a transaction, it means that the app started it, or we are being + // called recursively. If so, we cannot autocommit because whatever we + // do is likely not final. + // The control flow is such that autocommitable can only be turned on + // at the beginning, but never after. + safeSession.SetAutocommitable(mustCommit) + + if mustCommit { + commitStart := time.Now() + if err = e.txConn.Commit(ctx, safeSession); err != nil { + return nil, err + } + logStats.CommitTime = time.Since(commitStart) + } + return &sqltypes.Result{}, nil + case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, + sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtComment: + return &sqltypes.Result{}, nil + case sqlparser.StmtShow: + return e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) + } + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql) +} + +func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats) (*sqltypes.Result, error) { + // V3 mode. + query, comments := sqlparser.SplitMarginComments(sql) + vcursor := newVCursorImpl(ctx, safeSession, destKeyspace, destTabletType, comments, e, logStats) + plan, err := e.getPlan( + vcursor, + query, + comments, + bindVars, + skipQueryPlanCache(safeSession), + logStats, + ) + execStart := time.Now() + logStats.PlanTime = execStart.Sub(logStats.StartTime) + + if err != nil { + logStats.Error = err + return nil, err + } + + qr, err := plan.Instructions.GetFields(vcursor, bindVars) + logStats.ExecuteTime = time.Since(execStart) + var errCount uint64 + if err != nil { + logStats.Error = err + errCount = 1 + } else { + logStats.RowsAffected = qr.RowsAffected + } + + // Check if there was partial DML execution. If so, rollback the transaction. + if err != nil && safeSession.InTransaction() && vcursor.hasPartialDML { + _ = e.txConn.Rollback(ctx, safeSession) + err = vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction rolled back due to partial DML execution: %v", err) + } + + plan.AddStats(1, time.Since(logStats.StartTime), uint64(logStats.ShardQueries), logStats.RowsAffected, errCount) + + return qr, err +} diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index e1c0e3408f0..f2e024cab8a 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "os" + "strings" "sync/atomic" "syscall" "time" @@ -33,6 +34,7 @@ import ( "vitess.io/vitess/go/vt/callinfo" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttls" querypb "vitess.io/vitess/go/vt/proto/query" @@ -160,6 +162,92 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq return callback(result) } +// ComPrepare is the handler for command prepare. +func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } + + ctx = callinfo.MysqlCallInfo(ctx, c) + + // Fill in the ImmediateCallerID with the UserData returned by + // the AuthServer plugin for that user. If nothing was + // returned, use the User. This lets the plugin map a MySQL + // user used for authentication to a Vitess User used for + // Table ACLs and Vitess authentication in general. + im := c.UserData.Get() + ef := callerid.NewEffectiveCallerID( + c.User, /* principal: who */ + c.RemoteAddr().String(), /* component: running client process */ + "VTGate MySQL Connector" /* subcomponent: part of the client */) + ctx = callerid.NewContext(ctx, ef, im) + + session, _ := c.ClientData.(*vtgatepb.Session) + if session == nil { + session = &vtgatepb.Session{ + Options: &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + }, + Autocommit: true, + } + if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { + session.Options.ClientFoundRows = true + } + } + + if !session.InTransaction { + atomic.AddInt32(&busyConnections, 1) + } + defer func() { + if !session.InTransaction { + atomic.AddInt32(&busyConnections, -1) + } + }() + + if c.SchemaName != "" { + session.TargetString = c.SchemaName + } + + statement, err := sqlparser.ParseStrictDDL(query) + if err != nil { + err = mysql.NewSQLErrorFromError(err) + return err + } + + paramsCount := uint16(0) + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { + switch node := node.(type) { + case *sqlparser.SQLVal: + if strings.HasPrefix(string(node.Val), ":v") { + paramsCount++ + } + } + return true, nil + }, statement) + + prepare := c.PrepareData[c.StatementID] + prepare.ParsedStmt = &statement + + if paramsCount > 0 { + prepare.ParamsCount = paramsCount + prepare.ParamsType = make([]int32, paramsCount) + prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) + } + + session, result, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) + c.ClientData = session + err = mysql.NewSQLErrorFromError(err) + if err != nil { + return err + } + return callback(result) +} + func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 { session, _ := c.ClientData.(*vtgatepb.Session) if session != nil { diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 2b2e7b81d8c..f02f37433bb 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -832,6 +832,34 @@ func (vtg *VTGate) ResolveTransaction(ctx context.Context, dtid string) error { return formatError(vtg.txConn.Resolve(ctx, dtid)) } +// Prepare supports non-streaming prepare statement query with multi shards +func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, qr *sqltypes.Result, err error) { + // In this context, we don't care if we can't fully parse destination + destKeyspace, destTabletType, _, _ := vtg.executor.ParseDestinationTarget(session.TargetString) + statsKey := []string{"Execute", destKeyspace, topoproto.TabletTypeLString(destTabletType)} + defer vtg.timings.Record(statsKey, time.Now()) + + if bvErr := sqltypes.ValidateBindVariables(bindVariables); bvErr != nil { + err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", bvErr) + goto handleError + } + + qr, err = vtg.executor.Prepare(ctx, "Execute", NewSafeSession(session), sql, bindVariables) + if err == nil { + vtg.rowsReturned.Add(statsKey, int64(len(qr.Rows))) + return session, qr, nil + } + +handleError: + query := map[string]interface{}{ + "Sql": sql, + "BindVariables": bindVariables, + "Session": session, + } + err = recordAndAnnotateError(err, statsKey, query, vtg.logExecute) + return session, nil, err +} + // isKeyspaceRangeBasedSharded returns true if a keyspace is sharded // by range. This is true when there is a ShardingColumnType defined // in the SrvKeyspace (that is using the range-based sharding with the diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index 804cd207f8b..6bc1d82d6fd 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -136,6 +136,10 @@ func (mh *proxyHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } +func (mh *proxyHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener From 91f73b10669af120509030d27b6c728344008cea Mon Sep 17 00:00:00 2001 From: dcadevil Date: Wed, 16 Jan 2019 11:29:36 +0800 Subject: [PATCH 02/32] support for the MySQL prepare command protocol Signed-off-by: dcadevil --- go/mysql/conn.go | 168 ++++- go/mysql/fakesqldb/server.go | 5 + go/mysql/query.go | 755 ++++++++++++++++++++- go/mysql/server.go | 4 + go/vt/vtgate/plugin_mysql_server.go | 67 ++ go/vt/vtqueryserver/plugin_mysql_server.go | 4 + 6 files changed, 997 insertions(+), 6 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index c75dca2645f..bb52cc7a2d2 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -156,11 +156,11 @@ type Conn struct { StatementID uint32 - PrepareData map[uint32]*prepareData + PrepareData map[uint32]*PrepareData } -// prepareData is a buffer used for store prepare statement meta data -type prepareData struct { +// PrepareData is a buffer used for store prepare statement meta data +type PrepareData struct { StatementID uint32 PrepareStmt string ParsedStmt *sqlparser.Statement @@ -196,7 +196,7 @@ func newServerConn(conn net.Conn, listener *Listener) *Conn { conn: conn, listener: listener, closed: sync2.NewAtomicBool(false), - PrepareData: make(map[uint32]*prepareData), + PrepareData: make(map[uint32]*PrepareData), } if listener.connReadBufferSize > 0 { c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize) @@ -837,7 +837,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { } c.StatementID++ - prepare := &prepareData{ + prepare := &PrepareData{ StatementID: c.StatementID, PrepareStmt: queries[0], } @@ -872,6 +872,164 @@ func (c *Conn) handleNextCommand(handler Handler) error { delete(c.PrepareData, c.StatementID) return nil } + case ComStmtExecute: + queryStart := time.Now() + stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + c.recycleReadPacket() + if err != nil { + if stmtID != uint32(0) { + prepare := c.PrepareData[stmtID] + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return werr + } + return nil + } + + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + prepare := c.PrepareData[stmtID] + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeBinaryRows(qr) + }) + + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return werr + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return err + } + + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return err + } + } + } + + timings.Record(queryTimingKey, queryStart) + case ComStmtSendLongData: + stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) + c.recycleReadPacket() + if !ok { + err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + log.Error(err.Error()) + return err + } + + prepare, ok := c.PrepareData[stmtID] + if !ok { + err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) + log.Error(err.Error()) + return err + } + + if prepare.BindVars == nil || + prepare.ParamsCount == uint16(0) || + paramID >= prepare.ParamsCount { + err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) + log.Error(err.Error()) + return err + } + + chunk := make([]byte, len(chunkData)) + copy(chunk, chunkData) + + key := fmt.Sprintf("v%d", paramID+1) + if val, ok := prepare.BindVars[key]; ok { + val.Value = append(val.Value, chunk...) + } else { + v, err := sqltypes.InterfaceToValue(chunk) + if err != nil { + log.Error("build converted parameter value failed: %v", err) + return err + } + prepare.BindVars[key] = sqltypes.ValueBindVariable(v) + } + case ComStmtClose: + stmtID, ok := c.parseComStmtClose(data) + c.recycleReadPacket() + if ok { + delete(c.PrepareData, stmtID) + } + case ComStmtReset: + stmtID, ok := c.parseComStmtReset(data) + c.recycleReadPacket() + if !ok { + log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Error("Error writing error packet to client: %v", err) + return err + } + } + + prepare, ok := c.PrepareData[stmtID] + if !ok { + log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil { + log.Error("Error writing error packet to client: %v", err) + return err + } + } + + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + return err + } default: log.Errorf("Got unhandled packet from %s, returning error: %v", c, data) c.recycleReadPacket() diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 1f4d465c36c..9c169605ecd 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -437,6 +437,11 @@ func (db *DB) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Re return nil } +// ComStmtExecute is part of the mysql.Handler interface. +func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + // // Methods to add expected queries and results. // diff --git a/go/mysql/query.go b/go/mysql/query.go index 75ded7e2b34..0ad72813be9 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -18,6 +18,8 @@ package mysql import ( "fmt" + "math" + "strconv" "strings" "vitess.io/vitess/go/sqltypes" @@ -515,6 +517,343 @@ func (c *Conn) parseComPrepare(data []byte) string { return string(data[1:]) } +func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []byte) (uint32, byte, error) { + pos := 0 + payload := data[1:] + bitMap := make([]byte, 0) + + // statement ID + stmtID, pos, ok := readUint32(payload, 0) + if !ok { + return 0, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading statement ID failed") + } + prepare, ok := prepareData[stmtID] + if !ok { + return 0, 0, NewSQLError(CRCommandsOutOfSync, SSUnknownSQLState, "statement ID is not found from record") + } + + // cursor type flags + cursorType, pos, ok := readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading cursor type flags failed") + } + + // iteration count + iterCount, pos, ok := readUint32(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed") + } + if iterCount != uint32(1) { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1") + } + + if prepare.ParamsCount > 0 { + bitMap, pos, ok = readBytes(payload, pos, int((prepare.ParamsCount+7)/8)) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading NULL-bitmap failed") + } + } + + newParamsBoundFlag, pos, ok := readByte(payload, pos) + if newParamsBoundFlag == 0x01 { + var mysqlType, flags byte + for i := uint16(0); i < prepare.ParamsCount; i++ { + mysqlType, pos, ok = readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter type failed") + } + + flags, pos, ok = readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter flags failed") + } + + // convert MySQL type to internal type. + valType, err := sqltypes.MySQLToType(int64(mysqlType), int64(flags)) + if err != nil { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed: %v", mysqlType, flags, err) + } + + prepare.ParamsType[i] = int32(valType) + } + } + + for i := 0; i < len(prepare.ParamsType); i++ { + var val sqltypes.Value + parameterID := fmt.Sprintf("v%d", i+1) + if v, ok := prepare.BindVars[parameterID]; ok { + if v != nil { + continue + } + } + + if (bitMap[i/8] & (1 << uint(i%8))) > 0 { + val, pos, ok = c.parseStmtArgs(nil, sqltypes.Null, pos) + } else { + val, pos, ok = c.parseStmtArgs(payload, querypb.Type(prepare.ParamsType[i]), pos) + } + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding parameter value failed: %v", prepare.ParamsType[i]) + } + + prepare.BindVars[parameterID] = sqltypes.ValueBindVariable(val) + } + + return stmtID, cursorType, nil +} + +func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.Value, int, bool) { + switch typ { + case sqltypes.Null: + return sqltypes.NULL, pos, true + case sqltypes.Int8: + val, pos, ok := readByte(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Uint8: + val, pos, ok := readByte(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Uint16: + val, pos, ok := readUint16(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Int16, sqltypes.Year: + val, pos, ok := readUint16(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Uint24, sqltypes.Uint32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Int24, sqltypes.Int32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Float32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewFloat64(float64(math.Float32frombits(uint32(val)))), pos, ok + case sqltypes.Uint64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewUint64(val), pos, ok + case sqltypes.Int64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Float64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewFloat64(math.Float64frombits(val)), pos, ok + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + size, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + switch size { + case 0x00: + return sqltypes.NewVarChar(" "), pos, ok + case 0x0b: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + microSecond, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + " " + + strconv.Itoa(int(hour)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + "." + + strconv.Itoa(int(microSecond)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x07: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + " " + + strconv.Itoa(int(hour)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x04: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + + return sqltypes.NewVarChar(val), pos, ok + default: + return sqltypes.NULL, 0, false + } + case sqltypes.Time: + size, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + switch size { + case 0x00: + return sqltypes.NewVarChar("00:00:00"), pos, ok + case 0x0c: + isNegative, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + days, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + hours := uint32(hour) + days*uint32(24) + + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + microSecond, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + val := "" + if isNegative == 0x01 { + val += "-" + } + val += strconv.Itoa(int(hours)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + "." + + strconv.Itoa(int(microSecond)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x08: + isNegative, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + days, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + hours := uint32(hour) + days*uint32(24) + + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + val := "" + if isNegative == 0x01 { + val += "-" + } + val += strconv.Itoa(int(hours)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + + return sqltypes.NewVarChar(val), pos, ok + default: + return sqltypes.NULL, 0, false + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, sqltypes.VarBinary, sqltypes.Char, + sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + val, pos, ok := readLenEncStringAsBytes(data, pos) + return sqltypes.MakeTrusted(sqltypes.VarBinary, val), pos, ok + default: + return sqltypes.NULL, pos, false + } +} + +func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bool) { + pos := 1 + statementID, pos, ok := readUint32(data, pos) + if !ok { + return 0, 0, nil, false + } + + paramID, pos, ok := readUint16(data, pos) + if !ok { + return 0, 0, nil, false + } + + return statementID, paramID, data[pos:], true +} + +func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) { + val, _, ok := readUint32(data, 1) + return val, ok +} + +func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) { + val, _, ok := readUint32(data, 1) + return val, ok +} + func (c *Conn) parseComInitDB(data []byte) string { return string(data[1:]) } @@ -662,7 +1001,7 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn } // writePrepare writes a prepare query response to the wire. -func (c *Conn) writePrepare(result *sqltypes.Result, prepare *prepareData) error { +func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error { paramsCount := prepare.ParamsCount columnCount := 0 if result != nil { @@ -728,3 +1067,417 @@ func (c *Conn) writePrepare(result *sqltypes.Result, prepare *prepareData) error return c.flush() } + +func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) error { + length := 0 + nullBitMapLen := (len(fields) + 7 + 2) / 8 + for _, val := range row { + if !val.IsNull() { + l, err := val2MySQLLen(val) + if err != nil { + return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err) + } + length += l + } + } + + length += nullBitMapLen + 1 + + data := c.startEphemeralPacket(length) + pos := 0 + + pos = writeByte(data, pos, 0x00) + + for i := 0; i < nullBitMapLen; i++ { + pos = writeByte(data, pos, 0x00) + } + + for i, val := range row { + if val.IsNull() { + bytePos := (i+2)/8 + 1 + bitPos := (i + 2) % 8 + data[bytePos] |= 1 << uint(bitPos) + } else { + v, err := val2MySQL(val) + if err != nil { + return fmt.Errorf("internal value %v to MySQL value error: %v", val, err) + } + pos += copy(data[pos:], v) + } + } + + if pos != length { + return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length) + } + + return c.writeEphemeralPacket() +} + +// writeBinaryRows sends the rows of a Result with binary form. +func (c *Conn) writeBinaryRows(result *sqltypes.Result) error { + for _, row := range result.Rows { + if err := c.writeBinaryRow(result.Fields, row); err != nil { + return err + } + } + return nil +} + +func val2MySQL(v sqltypes.Value) ([]byte, error) { + var out []byte + pos := 0 + switch v.Type() { + case sqltypes.Null: + // no-op + case sqltypes.Int8: + val, err := strconv.ParseInt(v.ToString(), 10, 8) + if err != nil { + return []byte{}, err + } + out = make([]byte, 1) + writeByte(out, pos, uint8(val)) + case sqltypes.Uint8: + val, err := strconv.ParseUint(v.ToString(), 10, 8) + if err != nil { + return []byte{}, err + } + out = make([]byte, 1) + writeByte(out, pos, uint8(val)) + case sqltypes.Uint16: + val, err := strconv.ParseUint(v.ToString(), 10, 16) + if err != nil { + return []byte{}, err + } + out = make([]byte, 2) + writeUint16(out, pos, uint16(val)) + case sqltypes.Int16, sqltypes.Year: + val, err := strconv.ParseInt(v.ToString(), 10, 16) + if err != nil { + return []byte{}, err + } + out = make([]byte, 2) + writeUint16(out, pos, uint16(val)) + case sqltypes.Uint24, sqltypes.Uint32: + val, err := strconv.ParseUint(v.ToString(), 10, 32) + if err != nil { + return []byte{}, err + } + out = make([]byte, 4) + writeUint32(out, pos, uint32(val)) + case sqltypes.Int24, sqltypes.Int32: + val, err := strconv.ParseInt(v.ToString(), 10, 32) + if err != nil { + return []byte{}, err + } + out = make([]byte, 4) + writeUint32(out, pos, uint32(val)) + case sqltypes.Float32: + val, err := strconv.ParseFloat(v.ToString(), 32) + if err != nil { + return []byte{}, err + } + bits := math.Float32bits(float32(val)) + out = make([]byte, 4) + writeUint32(out, pos, bits) + case sqltypes.Uint64: + val, err := strconv.ParseUint(v.ToString(), 10, 64) + if err != nil { + return []byte{}, err + } + out = make([]byte, 8) + writeUint64(out, pos, uint64(val)) + case sqltypes.Int64: + val, err := strconv.ParseInt(v.ToString(), 10, 64) + if err != nil { + return []byte{}, err + } + out = make([]byte, 8) + writeUint64(out, pos, uint64(val)) + case sqltypes.Float64: + val, err := strconv.ParseFloat(v.ToString(), 64) + if err != nil { + return []byte{}, err + } + bits := math.Float64bits(val) + out = make([]byte, 8) + writeUint64(out, pos, bits) + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + if len(v.Raw()) > 19 { + out = make([]byte, 1+11) + out[pos] = 0x0b + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) + if err != nil { + return []byte{}, err + } + hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) + if err != nil { + return []byte{}, err + } + minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) + if err != nil { + return []byte{}, err + } + second, err := strconv.ParseUint(string(v.Raw()[17:19]), 10, 8) + if err != nil { + return []byte{}, err + } + val := make([]byte, 6) + count := copy(val, v.Raw()[20:]) + for i := 0; i < (6 - count); i++ { + val[count+i] = 0x30 + } + microSecond, err := strconv.ParseUint(string(val), 10, 32) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + pos = writeByte(out, pos, byte(hour)) + pos = writeByte(out, pos, byte(minute)) + pos = writeByte(out, pos, byte(second)) + pos = writeUint32(out, pos, uint32(microSecond)) + } else if len(v.Raw()) > 10 { + out = make([]byte, 1+7) + out[pos] = 0x07 + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) + if err != nil { + return []byte{}, err + } + hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) + if err != nil { + return []byte{}, err + } + minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) + if err != nil { + return []byte{}, err + } + second, err := strconv.ParseUint(string(v.Raw()[17:]), 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + pos = writeByte(out, pos, byte(hour)) + pos = writeByte(out, pos, byte(minute)) + pos = writeByte(out, pos, byte(second)) + } else if len(v.Raw()) > 0 { + out = make([]byte, 1+4) + out[pos] = 0x04 + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:]), 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + } else { + out = make([]byte, 1) + out[pos] = 0x00 + } + case sqltypes.Time: + if string(v.Raw()) == "00:00:00" { + out = make([]byte, 1) + out[pos] = 0x00 + } else if strings.Contains(string(v.Raw()), ".") { + out = make([]byte, 1+12) + out[pos] = 0x0c + pos++ + + sub1 := strings.Split(string(v.Raw()), ":") + if len(sub1) != 3 { + err := fmt.Errorf("incorrect time value, ':' is not found") + return []byte{}, err + } + sub2 := strings.Split(sub1[2], ".") + if len(sub2) != 2 { + err := fmt.Errorf("incorrect time value, '.' is not found") + return []byte{}, err + } + + var total []byte + if strings.HasPrefix(sub1[0], "-") { + out[pos] = 0x01 + total = []byte(sub1[0]) + total = total[1:] + } else { + out[pos] = 0x00 + total = []byte(sub1[0]) + } + pos++ + + h, err := strconv.ParseUint(string(total), 10, 32) + if err != nil { + return []byte{}, err + } + + days := uint32(h) / 24 + hours := uint32(h) % 24 + minute := sub1[1] + second := sub2[0] + microSecond := sub2[1] + + minutes, err := strconv.ParseUint(minute, 10, 8) + if err != nil { + return []byte{}, err + } + + seconds, err := strconv.ParseUint(second, 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(days)) + pos = writeByte(out, pos, byte(hours)) + pos = writeByte(out, pos, byte(minutes)) + pos = writeByte(out, pos, byte(seconds)) + + val := make([]byte, 6) + count := copy(val, microSecond) + for i := 0; i < (6 - count); i++ { + val[count+i] = 0x30 + } + microSeconds, err := strconv.ParseUint(string(val), 10, 32) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(microSeconds)) + } else if len(v.Raw()) > 0 { + out = make([]byte, 1+8) + out[pos] = 0x08 + pos++ + + sub1 := strings.Split(string(v.Raw()), ":") + if len(sub1) != 3 { + err := fmt.Errorf("incorrect time value, ':' is not found") + return []byte{}, err + } + + var total []byte + if strings.HasPrefix(sub1[0], "-") { + out[pos] = 0x01 + total = []byte(sub1[0]) + total = total[1:] + } else { + out[pos] = 0x00 + total = []byte(sub1[0]) + } + pos++ + + h, err := strconv.ParseUint(string(total), 10, 32) + if err != nil { + return []byte{}, err + } + + days := uint32(h) / 24 + hours := uint32(h) % 24 + minute := sub1[1] + second := sub1[2] + + minutes, err := strconv.ParseUint(minute, 10, 8) + if err != nil { + return []byte{}, err + } + + seconds, err := strconv.ParseUint(second, 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(days)) + pos = writeByte(out, pos, byte(hours)) + pos = writeByte(out, pos, byte(minutes)) + pos = writeByte(out, pos, byte(seconds)) + } else { + err := fmt.Errorf("incorrect time value") + return []byte{}, err + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, + sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, + sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + l := len(v.Raw()) + length := lenEncIntSize(uint64(l)) + l + out = make([]byte, length) + pos = writeLenEncInt(out, pos, uint64(l)) + copy(out[pos:], v.Raw()) + default: + out = make([]byte, len(v.Raw())) + copy(out, v.Raw()) + } + return out, nil +} + +func val2MySQLLen(v sqltypes.Value) (int, error) { + var length int + var err error + + switch v.Type() { + case sqltypes.Null: + length = 0 + case sqltypes.Int8, sqltypes.Uint8: + length = 1 + case sqltypes.Uint16, sqltypes.Int16, sqltypes.Year: + length = 2 + case sqltypes.Uint24, sqltypes.Uint32, sqltypes.Int24, sqltypes.Int32, sqltypes.Float32: + length = 4 + case sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64: + length = 8 + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + if len(v.Raw()) > 19 { + length = 12 + } else if len(v.Raw()) > 10 { + length = 8 + } else if len(v.Raw()) > 0 { + length = 5 + } else { + length = 1 + } + case sqltypes.Time: + if string(v.Raw()) == "00:00:00" { + length = 1 + } else if strings.Contains(string(v.Raw()), ".") { + length = 13 + } else if len(v.Raw()) > 0 { + length = 9 + } else { + err = fmt.Errorf("incorrect time value") + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, + sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, + sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + l := len(v.Raw()) + length = lenEncIntSize(uint64(l)) + l + default: + length = len(v.Raw()) + } + if err != nil { + return 0, err + } + return length, nil +} diff --git a/go/mysql/server.go b/go/mysql/server.go index a80023ffc9b..e9c47eaca07 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -97,6 +97,10 @@ type Handler interface { // statement query. ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error + // ComStmtExecute is called when a connection receives a statement + // execute query. + ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error + // WarningCount is called at the end of each query to obtain // the value to be returned to the client in the EOF packet. // Note that this will be called either in the context of the diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index f2e024cab8a..c507bd508cf 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -248,6 +248,73 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(* return callback(result) } +func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } + + ctx = callinfo.MysqlCallInfo(ctx, c) + + // Fill in the ImmediateCallerID with the UserData returned by + // the AuthServer plugin for that user. If nothing was + // returned, use the User. This lets the plugin map a MySQL + // user used for authentication to a Vitess User used for + // Table ACLs and Vitess authentication in general. + im := c.UserData.Get() + ef := callerid.NewEffectiveCallerID( + c.User, /* principal: who */ + c.RemoteAddr().String(), /* component: running client process */ + "VTGate MySQL Connector" /* subcomponent: part of the client */) + ctx = callerid.NewContext(ctx, ef, im) + + session, _ := c.ClientData.(*vtgatepb.Session) + if session == nil { + session = &vtgatepb.Session{ + Options: &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + }, + Autocommit: true, + } + if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { + session.Options.ClientFoundRows = true + } + } + + if !session.InTransaction { + atomic.AddInt32(&busyConnections, 1) + } + defer func() { + if !session.InTransaction { + atomic.AddInt32(&busyConnections, -1) + } + }() + + //if c.LastInsertID > 0 { + // c.PrevLastInsertID = c.LastInsertID + // c.LastInsertID = 0 + //} + + if c.SchemaName != "" { + session.TargetString = c.SchemaName + } + if session.Options.Workload == querypb.ExecuteOptions_OLAP { + err := vh.vtg.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) + return mysql.NewSQLErrorFromError(err) + } + _, qr, err := vh.vtg.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) + if err != nil { + err = mysql.NewSQLErrorFromError(err) + return err + } + + return callback(qr) +} + func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 { session, _ := c.ClientData.(*vtgatepb.Session) if session != nil { diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index 6bc1d82d6fd..6e86a21a172 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -140,6 +140,10 @@ func (mh *proxyHandler) ComPrepare(c *mysql.Conn, query string, callback func(*s return nil } +func (mh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener From 85f6d74ecb50baaaf53a7118728a1405a59529a3 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Wed, 16 Jan 2019 14:00:45 +0800 Subject: [PATCH 03/32] add missing function implementation Signed-off-by: dcadevil --- go/mysql/server_test.go | 8 ++++++++ go/vt/vtqueryserver/plugin_mysql_server_test.go | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 8895425a7cd..032e67c3bbe 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -171,6 +171,14 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } +func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *Conn) uint16 { return th.warnings } diff --git a/go/vt/vtqueryserver/plugin_mysql_server_test.go b/go/vt/vtqueryserver/plugin_mysql_server_test.go index ca0f6b806cd..936962075e7 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server_test.go +++ b/go/vt/vtqueryserver/plugin_mysql_server_test.go @@ -43,6 +43,14 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } From 602456683c564e96c95aca51bcd557cc0a9cda24 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Wed, 16 Jan 2019 14:31:09 +0800 Subject: [PATCH 04/32] delete unused code Signed-off-by: dcadevil --- go/vt/vtgate/plugin_mysql_server.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index c507bd508cf..dd0c53432a1 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -294,11 +294,6 @@ func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareDat } }() - //if c.LastInsertID > 0 { - // c.PrevLastInsertID = c.LastInsertID - // c.LastInsertID = 0 - //} - if c.SchemaName != "" { session.TargetString = c.SchemaName } From 25d16434e45ca08898dfc58792f4d941a7b3d271 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Wed, 16 Jan 2019 15:22:00 +0800 Subject: [PATCH 05/32] add missing function implementation Signed-off-by: dcadevil --- go/vt/vtgate/plugin_mysql_server_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index c873206e6ad..a751335f30d 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -43,6 +43,10 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } From 9f0e7c3644c87ecec40aac9ba82ec06f8ca70469 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Wed, 16 Jan 2019 15:52:34 +0800 Subject: [PATCH 06/32] add missing function implementation Signed-off-by: dcadevil --- go/vt/vtgate/plugin_mysql_server_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index a751335f30d..07d03fa97b2 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -43,6 +43,10 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { return nil } From 24d0dd0272e907d090b7417d0d50ebe2ca839994 Mon Sep 17 00:00:00 2001 From: dcadevil Date: Thu, 17 Jan 2019 14:58:20 +0800 Subject: [PATCH 07/32] add comments Signed-off-by: dcadevil --- go/mysql/conn.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index bb52cc7a2d2..5b14520efc7 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -154,8 +154,10 @@ type Conn struct { // It can be allocated from bufPool or heap and should be recycled in the same manner. currentEphemeralBuffer *[]byte + // StatementID is the prepared statement ID. StatementID uint32 + // PrepareData is the map to use a prepared statement. PrepareData map[uint32]*PrepareData } From 82c5e90ac88f4be10f61817878ade26d2800573e Mon Sep 17 00:00:00 2001 From: deepthi Date: Mon, 17 Jun 2019 13:26:26 -0700 Subject: [PATCH 08/32] fix compile errors after merge Signed-off-by: deepthi --- go/vt/vtgate/executor.go | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index d79d8342789..6a40e6751d0 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1425,8 +1425,7 @@ func (e *Executor) Prepare(ctx context.Context, method string, safeSession *Safe func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (*sqltypes.Result, error) { // Start an implicit transaction if necessary. - // TODO(sougou): deprecate legacyMode after all users are migrated out. - if !e.legacyAutocommit && !safeSession.Autocommit && !safeSession.InTransaction() { + if !safeSession.Autocommit && !safeSession.InTransaction() { if err := e.txConn.Begin(ctx, safeSession); err != nil { return nil, err } @@ -1464,11 +1463,6 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: safeSession := safeSession - // In legacy mode, we ignore autocommit settings. - if e.legacyAutocommit { - return &sqltypes.Result{}, nil - } - mustCommit := false if safeSession.Autocommit && !safeSession.InTransaction() { mustCommit = true @@ -1488,7 +1482,7 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st // do is likely not final. // The control flow is such that autocommitable can only be turned on // at the beginning, but never after. - safeSession.SetAutocommitable(mustCommit) + safeSession.SetAutocommittable(mustCommit) if mustCommit { commitStart := time.Now() From 8d641ee0fb01084a959d27704abf8359d6fa6fed Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 15 Jul 2019 10:23:56 -0700 Subject: [PATCH 09/32] Added test for ComPrepare Signed-off-by: Saif Alharthi --- go/mysql/query.go | 3 ++- go/mysql/query_test.go | 58 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/go/mysql/query.go b/go/mysql/query.go index 5cacc0ad8d1..fff8e701a09 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -545,6 +545,7 @@ func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []b if !ok { return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed") } + fmt.Printf("IterationCount: %v", iterCount) if iterCount != uint32(1) { return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1") } @@ -1016,7 +1017,7 @@ func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error data := c.startEphemeralPacket(12) pos := 0 - pos = writeByte(data, pos, 0x00) + pos = writeByte(data, pos, ComPrepare) pos = writeUint32(data, pos, uint32(prepare.StatementID)) pos = writeUint16(data, pos, uint16(columnCount)) pos = writeUint16(data, pos, uint16(paramsCount)) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 14625d1e30b..ed8680c570c 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -27,8 +27,42 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" ) +func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { + sql := "SELECT id FROM table_1 WHERE id=?" + + statement, err := sqlparser.Parse(sql) + if err != nil { + t.Fatalf("Sql parinsg failed: %v", err) + } + + result := &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "id", + Type: querypb.Type_INT32, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + }, + }, + RowsAffected: 1, + } + + prepare := &PrepareData{ + StatementID: 18, + PrepareStmt: sql, + ParsedStmt: &statement, + ParamsCount: 1, + } + + return prepare, result +} + func TestComInitDB(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { @@ -76,6 +110,30 @@ func TestComSetOption(t *testing.T) { } } +func TestComStmtPrepare(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, result := MockPrepareData(t) + + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + if err := cConn.writePrepare(result, prepare); err != nil { + t.Fatalf("writePrepare failed: %v", err) + } + data, err := sConn.ReadPacket() + if err != nil || len(data) == 0 || data[0] != ComPrepare { + t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v %v", data, err) + } + if uint32(data[1]) != prepare.StatementID { + t.Fatalf("Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) + } +} + func TestQueries(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { From 96ed068636604534ccef572a3d94c2df252d9a77 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 15 Jul 2019 15:29:28 -0700 Subject: [PATCH 10/32] added few more tests in query_test Signed-off-by: Saif Alharthi --- go/mysql/query_test.go | 66 +++++++++++++++++++++++++++++++++++++++++ go/mysql/server_test.go | 14 +++++++++ 2 files changed, 80 insertions(+) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index ed8680c570c..0ae4cc760e2 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -134,6 +134,72 @@ func TestComStmtPrepare(t *testing.T) { } } +func TestComStmtSendLongData(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, result := MockPrepareData(t) + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + if err := cConn.writePrepare(result, prepare); err != nil { + t.Fatalf("writePrepare failed: %v", err) + } + + // Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID + data, err := sConn.ReadPacket() + if err != nil || len(data) == 0 || data[0] != ComPrepare { + t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) + } + stmtID, paramID, chunkData, ok := sConn.parseComStmtSendLongData(data) + if !ok { + t.Fatalf("parseComStmtSendLongData failed") + } + if paramID != 1 { + t.Fatalf("Recieved incorrect ParamID, want %v, got %v:", paramID, 1) + } + if stmtID != prepare.StatementID { + t.Fatalf("Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) + } + // Check length of chunkData, Since its a subset of `data` and compare with it after we subtract the number of bytes that was read from it. + // sizeof(uint32) + sizeof(uint16) + 1 = 7 + if len(chunkData) != len(data)-7 { + t.Fatalf("Recieved bad chunkData") + } +} + +func TestComStmtClose(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, result := MockPrepareData(t) + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + if err := cConn.writePrepare(result, prepare); err != nil { + t.Fatalf("writePrepare failed: %v", err) + } + + // Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID + data, err := sConn.ReadPacket() + if err != nil || len(data) == 0 || data[0] != ComPrepare { + t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) + } + stmtID, ok := sConn.parseComStmtClose(data) + if !ok { + t.Fatalf("parseComStmtClose failed") + } + if stmtID != prepare.StatementID { + t.Fatalf("Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) + } +} + func TestQueries(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 032e67c3bbe..10c525cdbbf 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -74,7 +74,11 @@ func (th *testHandler) NewConnection(c *Conn) { th.lastConn = c } +// Should we return boolean here? func (th *testHandler) ConnectionClosed(c *Conn) { + if c.closed.Get() != false { + c.Close() + } } func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { @@ -171,11 +175,21 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } +// TODO(saifalharthi) firgure out how to validate the prepare statements using callback func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error { + if th.result != nil { + callback(th.result) + return nil + } return nil } +// TODO(saifalharthi) firgure out how to invoke prepared statement execution using callback func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { + if th.result != nil { + callback(th.result) + return nil + } return nil } From 1470f76320744c13e237454a1b98113b117fb4a8 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 17 Jul 2019 15:27:16 -0700 Subject: [PATCH 11/32] Added executor tests. They are faulty now. Signed-off-by: Saif Alharthi --- go/vt/vtgate/executor_dml_test.go | 227 ++++++++++++++++++++++++ go/vt/vtgate/executor_framework_test.go | 9 + go/vt/vtgate/executor_select_test.go | 122 +++++++++++++ go/vt/vtgate/executor_test.go | 1 - test/utils.py | 4 +- 5 files changed, 360 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index da3d5720aca..833f8c2e9d9 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1755,3 +1755,230 @@ func TestKeyShardDestQuery(t *testing.T) { sbc2.Queries = nil masterSession.TargetString = "" } + +// Prepared statement tests + +func TestUpdateEqualWithPrepare(t *testing.T) { + executor, sbc1, sbc2, sbclookup := createExecutorEnv() + + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + _, err := executorPrepare(executor, "update music set a = :a0 where id = :id0", map[string]*querypb.BindVariable{ + "a0": sqltypes.Int64BindVariable(3), + "id0": sqltypes.Int64BindVariable(2), + }) + if err != nil { + t.Error(err) + } + wantQueries := []*querypb.BoundQuery{{ + Sql: "select user_id from music_user_map where music_id = :music_id", + BindVariables: map[string]*querypb.BindVariable{ + "music_id": sqltypes.Int64BindVariable(2), + }, + }} + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { + t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) + } + if sbc2.Queries != nil { + t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) + } + if sbc1.Queries != nil { + t.Errorf("sbc1.Queries: %+v, want nil\n", sbc1.Queries) + } +} +func TestInsertShardedWithPrepare(t *testing.T) { + executor, sbc1, sbc2, sbclookup := createExecutorEnv() + + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + _, err := executorPrepare(executor, "insert into user(id, v, name) values (:_Id0, 2, ':_name0')", map[string]*querypb.BindVariable{ + "_Id0": sqltypes.Int64BindVariable(1), + "_name0": sqltypes.BytesBindVariable([]byte("myname")), + "__seq0": sqltypes.Int64BindVariable(1), + }) + if err != nil { + t.Error(err) + } + wantQueries := []*querypb.BoundQuery{{ + Sql: "insert into user(id, v, name) values (:_Id0, 2, :_name0) /* vtgate:: keyspace_id:166b40b44aba4bd6 */", + BindVariables: map[string]*querypb.BindVariable{ + "_Id0": sqltypes.Int64BindVariable(1), + "_name0": sqltypes.BytesBindVariable([]byte("myname")), + "__seq0": sqltypes.Int64BindVariable(1), + }, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) + } + if sbc2.Queries != nil { + t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "insert into name_user_map(name, user_id) values (:name0, :user_id0)", + BindVariables: map[string]*querypb.BindVariable{ + "name0": sqltypes.BytesBindVariable([]byte("myname")), + "user_id0": sqltypes.Uint64BindVariable(1), + }, + }} + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { + t.Errorf("sbclookup.Queries: \n%+v, want \n%+v", sbclookup.Queries, wantQueries) + } + + testQueryLog(t, logChan, "VindexCreate", "INSERT", "insert into name_user_map(name, user_id) values(:name0, :user_id0)", 1) + testQueryLog(t, logChan, "TestExecute", "INSERT", "insert into user(id, v, name) values (1, 2, 'myname')", 1) + + sbc1.Queries = nil + sbclookup.Queries = nil + // Test without binding variables. + _, err = executorPrepare(executor, "insert into user(id, v, name) values (3, 2, 'myname2')", nil) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "insert into user(id, v, name) values (:_Id0, 2, :_name0) /* vtgate:: keyspace_id:4eb190c9a2fa169c */", + BindVariables: map[string]*querypb.BindVariable{ + "_Id0": sqltypes.Int64BindVariable(3), + "__seq0": sqltypes.Int64BindVariable(3), + "_name0": sqltypes.BytesBindVariable([]byte("myname2")), + }, + }} + if !reflect.DeepEqual(sbc2.Queries, wantQueries) { + t.Errorf("sbc2.Queries:\n%+v, want\n%+v\n", sbc2.Queries, wantQueries) + } + if sbc1.Queries != nil { + t.Errorf("sbc1.Queries: %+v, want nil\n", sbc1.Queries) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "insert into name_user_map(name, user_id) values (:name0, :user_id0)", + BindVariables: map[string]*querypb.BindVariable{ + "name0": sqltypes.BytesBindVariable([]byte("myname2")), + "user_id0": sqltypes.Uint64BindVariable(3), + }, + }} + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { + t.Errorf("sbclookup.Queries: \n%+v, want \n%+v\n", sbclookup.Queries, wantQueries) + } + + sbc1.Queries = nil + // Check if execution works. + _, err = executorExec(executor, "insert into user2(id, name, lastname) values (2, 'myname', 'mylastname')", nil) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "insert into user2(id, name, lastname) values (:_id0, :_name0, :_lastname0) /* vtgate:: keyspace_id:06e7ea22ce92708f */", + BindVariables: map[string]*querypb.BindVariable{ + "_id0": sqltypes.Int64BindVariable(2), + "_name0": sqltypes.BytesBindVariable([]byte("myname")), + "_lastname0": sqltypes.BytesBindVariable([]byte("mylastname")), + }, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) + } +} + +func TestDeleteEqualWithPrepare(t *testing.T) { + executor, sbc, _, sbclookup := createExecutorEnv() + + sbc.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "name", Type: sqltypes.VarChar}, + }, + RowsAffected: 1, + InsertID: 0, + Rows: [][]sqltypes.Value{{ + sqltypes.NewVarChar("myname"), + }}, + }}) + _, err := executorPrepare(executor, "delete from user where id = :id0", map[string]*querypb.BindVariable{ + "id0": sqltypes.Int64BindVariable(1), + }) + if err != nil { + t.Error(err) + } + // In execute, queries get re-written. + // TODO(saifalharthi) check if there was a way to debug this properly. + wantQueries := []*querypb.BoundQuery{{ + Sql: "select name from user where id = 1 for update", + BindVariables: map[string]*querypb.BindVariable{}, + }, { + Sql: "delete from user where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", + BindVariables: map[string]*querypb.BindVariable{}, + }} + if !reflect.DeepEqual(sbc.Queries, wantQueries) { + t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) + } + + wantQueries = []*querypb.BoundQuery{{ + Sql: "delete from name_user_map where name = :name and user_id = :user_id", + BindVariables: map[string]*querypb.BindVariable{ + "user_id": sqltypes.Uint64BindVariable(1), + "name": sqltypes.StringBindVariable("myname"), + }, + }} + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { + t.Errorf("sbclookup.Queries:\n%+v, want\n%+v\n", sbclookup.Queries, wantQueries) + } + + sbc.Queries = nil + sbclookup.Queries = nil + sbc.SetResults([]*sqltypes.Result{{}}) + _, err = executorExec(executor, "delete from user where id = 1", nil) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "select name from user where id = 1 for update", + BindVariables: map[string]*querypb.BindVariable{}, + }, { + Sql: "delete from user where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", + BindVariables: map[string]*querypb.BindVariable{}, + }} + if !reflect.DeepEqual(sbc.Queries, wantQueries) { + t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) + } + if sbclookup.Queries != nil { + t.Errorf("sbclookup.Queries: %+v, want nil\n", sbclookup.Queries) + } + + sbc.Queries = nil + sbclookup.Queries = nil + sbclookup.SetResults([]*sqltypes.Result{{}}) + _, err = executorExec(executor, "delete from music where id = 1", nil) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "select user_id from music_user_map where music_id = :music_id", + BindVariables: map[string]*querypb.BindVariable{ + "music_id": sqltypes.Int64BindVariable(1), + }, + }} + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { + t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) + } + if sbc.Queries != nil { + t.Errorf("sbc.Queries: %+v, want nil\n", sbc.Queries) + } + + sbc.Queries = nil + sbclookup.Queries = nil + sbclookup.SetResults([]*sqltypes.Result{{}}) + _, err = executorExec(executor, "delete from user_extra where user_id = 1", nil) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "delete from user_extra where user_id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", + BindVariables: map[string]*querypb.BindVariable{}, + }} + if !reflect.DeepEqual(sbc.Queries, wantQueries) { + t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) + } + if sbclookup.Queries != nil { + t.Errorf("sbc.Queries: %+v, want nil\n", sbc.Queries) + } +} diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index d2b3c3f1297..85ac2c5c3e6 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -392,6 +392,15 @@ func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVar bv) } +func executorPrepare(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) { + return executor.Prepare( + context.Background(), + "TestExecute", + NewSafeSession(masterSession), + sql, + bv) +} + func executorStream(executor *Executor, sql string) (qr *sqltypes.Result, err error) { results := make(chan *sqltypes.Result, 100) err = executor.StreamExecute( diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 00d7a637268..4d94b861da4 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2034,3 +2034,125 @@ func TestCrossShardSubqueryGetFields(t *testing.T) { t.Errorf("result: %+v, want %+v", result, wantResult) } } + +func TestSelectBindvarswithPrepare(t *testing.T) { + executor, sbc1, sbc2, lookup := createExecutorEnv() + logChan := QueryLogger.Subscribe("Test") + defer QueryLogger.Unsubscribe(logChan) + + sql := "select id from user where id = :id" + _, err := executorPrepare(executor, sql, map[string]*querypb.BindVariable{ + "id": sqltypes.Int64BindVariable(1), + }) + if err != nil { + t.Error(err) + } + + wantQueries := []*querypb.BoundQuery{{ + Sql: "select id from user where id = :id", + BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)}, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) + } + if sbc2.Queries != nil { + t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) + } + sbc1.Queries = nil + testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) + + // Test with StringBindVariable + sql = "select id from user where name in (:name1, :name2)" + _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ + "name1": sqltypes.StringBindVariable("foo1"), + "name2": sqltypes.StringBindVariable("foo2"), + }) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "select id from user where name in ::__vals", + BindVariables: map[string]*querypb.BindVariable{ + "name1": sqltypes.StringBindVariable("foo1"), + "name2": sqltypes.StringBindVariable("foo2"), + "__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}), + }, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) + } + sbc1.Queries = nil + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) + testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) + + // Test with BytesBindVariable + sql = "select id from user where name in (:name1, :name2)" + _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ + "name1": sqltypes.BytesBindVariable([]byte("foo1")), + "name2": sqltypes.BytesBindVariable([]byte("foo2")), + }) + if err != nil { + t.Error(err) + } + wantQueries = []*querypb.BoundQuery{{ + Sql: "select id from user where name in ::__vals", + BindVariables: map[string]*querypb.BindVariable{ + "name1": sqltypes.BytesBindVariable([]byte("foo1")), + "name2": sqltypes.BytesBindVariable([]byte("foo2")), + "__vals": sqltypes.TestBindVariable([]interface{}{[]byte("foo1"), []byte("foo2")}), + }, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) + } + + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) + testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) + + // Test no match in the lookup vindex + sbc1.Queries = nil + lookup.Queries = nil + lookup.SetResults([]*sqltypes.Result{{ + Fields: []*querypb.Field{ + {Name: "user_id", Type: sqltypes.Int32}, + }, + RowsAffected: 0, + InsertID: 0, + Rows: [][]sqltypes.Value{}, + }}) + + sql = "select id from user where name = :name" + _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ + "name": sqltypes.StringBindVariable("nonexistent"), + }) + if err != nil { + t.Error(err) + } + + // When there are no matching rows in the vindex, vtgate still needs the field info + wantQueries = []*querypb.BoundQuery{{ + Sql: "select id from user where 1 != 1", + BindVariables: map[string]*querypb.BindVariable{ + "name": sqltypes.StringBindVariable("nonexistent"), + }, + }} + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { + t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) + } + + wantLookupQueries := []*querypb.BoundQuery{{ + Sql: "select user_id from name_user_map where name = :name", + BindVariables: map[string]*querypb.BindVariable{ + "name": sqltypes.StringBindVariable("nonexistent"), + }, + }} + if !reflect.DeepEqual(lookup.Queries, wantLookupQueries) { + t.Errorf("lookup.Queries: %+v, want %+v\n", lookup.Queries, wantLookupQueries) + } + + testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) + testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) + +} diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index f64afa2080c..570b7c541e6 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -437,7 +437,6 @@ func TestExecutorSet(t *testing.T) { } } } - func TestExecutorAutocommit(t *testing.T) { executor, _, _, sbclookup := createExecutorEnv() session := NewSafeSession(&vtgatepb.Session{TargetString: "@master"}) diff --git a/test/utils.py b/test/utils.py index e73dccfdfd2..b51d274e211 100644 --- a/test/utils.py +++ b/test/utils.py @@ -115,7 +115,7 @@ def add_options(parser): help='Leave the global processes running after the test is done.') parser.add_option('--mysql-flavor') parser.add_option('--protocols-flavor', default='grpc') - parser.add_option('--topo-server-flavor', default='zk2') + parser.add_option('--topo-server-flavor', default='etcd2') parser.add_option('--vtgate-gateway-flavor', default='discoverygateway') @@ -1019,7 +1019,7 @@ def check_db_var(uid, name, value): user='vt_dba', unix_socket='%s/vt_%010d/mysql.sock' % (environment.vtdataroot, uid)) cursor = conn.cursor() - cursor.execute("show variables like '%s'" % name) + cursor.execute("show variables like '%s'", name) row = cursor.fetchone() if row != (name, value): raise TestError('variable not set correctly', name, row) From 698243b3d0d46b52cdf406bf092d61b242df834b Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Fri, 19 Jul 2019 17:29:11 -0700 Subject: [PATCH 12/32] Added python end to end test Signed-off-by: Saif Alharthi --- go/mysql/query.go | 3 +- go/mysql/query_test.go | 2 +- go/sqltypes/type.go | 1 + test/config.json | 9 + test/prepared_statement_test.py | 285 ++++++++++++++++++++++++++++++++ 5 files changed, 297 insertions(+), 3 deletions(-) create mode 100755 test/prepared_statement_test.py diff --git a/go/mysql/query.go b/go/mysql/query.go index fff8e701a09..b38f9c89158 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -146,7 +146,6 @@ func (c *Conn) readColumnDefinition(field *querypb.Field, index int) error { if err != nil { return NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed for column %v: %v", t, flags, index, err) } - // Decimals is a byte. decimals, _, ok := readByte(colDef, pos) if !ok { @@ -1017,7 +1016,7 @@ func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error data := c.startEphemeralPacket(12) pos := 0 - pos = writeByte(data, pos, ComPrepare) + pos = writeByte(data, pos, 0x00) pos = writeUint32(data, pos, uint32(prepare.StatementID)) pos = writeUint16(data, pos, uint16(columnCount)) pos = writeUint16(data, pos, uint16(paramsCount)) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 0ae4cc760e2..d06cfa3329c 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -126,7 +126,7 @@ func TestComStmtPrepare(t *testing.T) { t.Fatalf("writePrepare failed: %v", err) } data, err := sConn.ReadPacket() - if err != nil || len(data) == 0 || data[0] != ComPrepare { + if err != nil || len(data) == 0 { t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v %v", data, err) } if uint32(data[1]) != prepare.StatementID { diff --git a/go/sqltypes/type.go b/go/sqltypes/type.go index cf1bed67a42..b123e882d72 100644 --- a/go/sqltypes/type.go +++ b/go/sqltypes/type.go @@ -167,6 +167,7 @@ var mysqlToType = map[int64]querypb.Type{ 11: Time, 12: Datetime, 13: Year, + 15: VarChar, 16: Bit, 245: TypeJSON, 246: Decimal, diff --git a/test/config.json b/test/config.json index 3ea489fb98e..0951f33fa22 100644 --- a/test/config.json +++ b/test/config.json @@ -256,6 +256,15 @@ "RetryMax": 0, "Tags": [] }, + "prepared_statement": { + "File": "prepared_statement_test.py", + "Args": [], + "Command": [], + "Manual": false, + "Shard": 4, + "RetryMax": 0, + "Tags": [] + }, "mysqlctl": { "File": "mysqlctl.py", "Args": [], diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py new file mode 100755 index 00000000000..585bd54698e --- /dev/null +++ b/test/prepared_statement_test.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python +# +# Copyright 2019 The Vitess Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Ensures the vtgate MySQL server protocol plugin works as expected with prepared statments. + +We use table ACLs to verify the user name authenticated by the connector is +set properly. +""" + +import datetime +import socket +import unittest + +import mysql.connector +from mysql.connector import FieldType +from mysql.connector.cursor import MySQLCursorPrepared + +import environment +import utils +import tablet +import warnings + +# single shard / 2 tablets +shard_0_master = tablet.Tablet() +shard_0_slave = tablet.Tablet() + +table_acl_config = environment.tmproot + '/table_acl_config.json' +mysql_auth_server_static = (environment.tmproot + + '/mysql_auth_server_static.json') + + +json_example = '''{ + "quiz": { + "sport": { + "q1": { + "question": "Which one is correct team name in NBA?", + "options": [ + "New York Bulls", + "Los Angeles Kings", + "Golden State Warriros", + "Huston Rocket" + ], + "answer": "Huston Rocket" + } + }, + "maths": { + "q1": { + "question": "5 + 7 = ?", + "options": [ + "10", + "11", + "12", + "13" + ], + "answer": "12" + }, + "q2": { + "question": "12 - 8 = ?", + "options": [ + "1", + "2", + "3", + "4" + ], + "answer": "4" + } + } + } +}''' + +insert_stmt = '''insert into vt_prepare_stmt_test values(%s, %s, %s, %s, %s, %s, %s, + %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)''' + +def setUpModule(): + try: + environment.topo_server().setup() + + # setup all processes + setup_procs = [ + shard_0_master.init_mysql(), + shard_0_slave.init_mysql(), + ] + utils.wait_procs(setup_procs) + + utils.run_vtctl(['CreateKeyspace', 'test_keyspace']) + + shard_0_master.init_tablet('replica', 'test_keyspace', '0') + shard_0_slave.init_tablet('replica', 'test_keyspace', '0') + + # create databases so vttablet can start behaving normally + shard_0_master.create_db('vt_test_keyspace') + shard_0_slave.create_db('vt_test_keyspace') + + except: + tearDownModule() + raise + + +def tearDownModule(): + utils.required_teardown() + if utils.options.skip_teardown: + return + + shard_0_master.kill_vttablet() + shard_0_slave.kill_vttablet() + + teardown_procs = [ + shard_0_master.teardown_mysql(), + shard_0_slave.teardown_mysql(), + ] + utils.wait_procs(teardown_procs, raise_on_error=False) + + environment.topo_server().teardown() + utils.kill_sub_processes() + utils.remove_tmp_files() + + shard_0_master.remove_tree() + shard_0_slave.remove_tree() + + +create_vt_prepare_test = '''create table vt_prepare_stmt_test ( +id bigint auto_increment, +msg varchar(64), +keyspace_id bigint(20) unsigned NOT NULL, +tinyint_unsigned TINYINT, +bool_signed BOOL, +smallint_unsigned SMALLINT, +mediumint_unsigned MEDIUMINT, +int_unsigned INT, +float_unsigned FLOAT(10,2), +double_unsigned DOUBLE(16,2), +decimal_unsigned DECIMAL, +t_date DATE, +t_datetime DATETIME, +t_time TIME, +t_timestamp TIMESTAMP, +c8 bit(8) DEFAULT NULL, +c16 bit(16) DEFAULT NULL, +c24 bit(24) DEFAULT NULL, +c32 bit(32) DEFAULT NULL, +c40 bit(40) DEFAULT NULL, +c48 bit(48) DEFAULT NULL, +c56 bit(56) DEFAULT NULL, +c63 bit(63) DEFAULT NULL, +c64 bit(64) DEFAULT NULL, +json_col JSON, +text_col TEXT, +data longblob, +primary key (id) +) Engine=InnoDB''' + + +class TestMySQL(unittest.TestCase): + """This test makes sure the MySQL server connector is correct. + """ + + def test_mysql_connector(self): + with open(table_acl_config, 'w') as fd: + fd.write("""{ + "table_groups": [ + { + "table_names_or_prefixes": ["vt_prepare_stmt_test", "dual"], + "readers": ["vtgate client 1"], + "writers": ["vtgate client 1"], + "admins": ["vtgate client 1"] + } + ] +} +""") + + with open(mysql_auth_server_static, 'w') as fd: + fd.write("""{ + "testuser1": { + "Password": "testpassword1", + "UserData": "vtgate client 1" + }, + "testuser2": { + "Password": "testpassword2", + "UserData": "vtgate client 2" + } +} +""") + + # start the tablets + shard_0_master.start_vttablet(wait_for_state='NOT_SERVING', + table_acl_config=table_acl_config) + shard_0_slave.start_vttablet(wait_for_state='NOT_SERVING', + table_acl_config=table_acl_config) + + # setup replication + utils.run_vtctl(['InitShardMaster', '-force', 'test_keyspace/0', + shard_0_master.tablet_alias], auto_log=True) + utils.run_vtctl(['ApplySchema', '-sql', create_vt_prepare_test, + 'test_keyspace']) + for t in [shard_0_master, shard_0_slave]: + utils.run_vtctl(['RunHealthCheck', t.tablet_alias]) + + # start vtgate + utils.VtGate(mysql_server=True).start( + extra_args=['-mysql_auth_server_impl', 'static', + '-mysql_server_query_timeout', '1s', + '-mysql_auth_server_static_file', mysql_auth_server_static]) + # We use gethostbyname('localhost') so we don't presume + # of the IP format (travis is only IP v4, really). + params = dict(host=socket.gethostbyname('localhost'), + port=utils.vtgate.mysql_port, + user='testuser1', + passwd='testpassword1', + db='test_keyspace', + use_pure=True) + + # 'vtgate client 1' is authorized to access vt_prepare_insert_test + conn = mysql.connector.Connect(**params) + cursor = conn.cursor() + cursor.execute('select * from vt_prepare_stmt_test', {}) + cursor.fetchone() + cursor.close() + + # Insert several rows using prepared statements + text_value = "text" * 100 # Large text value + largeComment = 'L' * ((4 * 1024 * 1024) + 1) # Large blob + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + for i in range(1, 100): + insert_values = (i, str(i) + "21", i * 100, 127, 1, 32767, 8388607, 2147483647, 2.55, 64.9,55.5, + datetime.date(2009, 5, 5), datetime.date(2009, 5, 5), datetime.datetime.now().time(), datetime.date(2009, 5, 5), + 1,1,1,1,1,1,1,1,1, json_example, text_value, largeComment) + cursor.execute(insert_stmt, insert_values) + + cursor.fetchone() + cursor.close() + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('select * from vt_prepare_stmt_test where id = %s', (1,)) + result = cursor.fetchall() + + # Validate the query results. + if cursor.rowcount != 1: + self.fail('expected 1 row got ' + str(cursor.rowcount)) + + if result[0][1] != "121": + self.fail('Received incorrect value, wanted: 121, got ' + result[1]) + + cursor.close() + + updated_text_value = "text_col_msg" + updated_data_value = "updated" + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('update vt_prepare_stmt_test set data = %s , text_col = %s where id = %s', (updated_data_value, updated_text_value, 1)) + cursor.close() + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('select * from vt_prepare_stmt_test where id = %s', (1,)) + result = cursor.fetchone() + if result[-1] != updated_data_value or result[-2] != updated_text_value: + self.fail("Received incorrect values") + cursor.close() + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('delete from vt_prepare_stmt_test where text_col = %s', (text_value,)) + cursor.close() + + cursor = conn.cursor(cursor_class=MySQLCursorPrepared) + cursor.execute('select count(*) from vt_prepare_stmt_test') + res = cursor.fetchone() + if res[0] != 1: + self.fail("Delete did no") + cursor.close() + +if __name__ == '__main__': + utils.main() From 0d396074b4e6bdce491ef8b4772a52505ee5942a Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Fri, 19 Jul 2019 17:48:01 -0700 Subject: [PATCH 13/32] Fix error message Signed-off-by: Saif Alharthi --- test/prepared_statement_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py index 585bd54698e..387e1b1c1d5 100755 --- a/test/prepared_statement_test.py +++ b/test/prepared_statement_test.py @@ -278,7 +278,7 @@ def test_mysql_connector(self): cursor.execute('select count(*) from vt_prepare_stmt_test') res = cursor.fetchone() if res[0] != 1: - self.fail("Delete did no") + self.fail("Delete failed") cursor.close() if __name__ == '__main__': From bc1ac07b77a4dda37bc21f0dfadb0ce235e00a1b Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 22 Jul 2019 11:53:37 -0700 Subject: [PATCH 14/32] Added test for TestComStmtExecute Signed-off-by: Saif Alharthi --- go/mysql/query.go | 1 - go/mysql/query_test.go | 37 +++++++++++++++++++++++++++++++++---- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/go/mysql/query.go b/go/mysql/query.go index b38f9c89158..2ba90f4ead7 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -544,7 +544,6 @@ func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []b if !ok { return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed") } - fmt.Printf("IterationCount: %v", iterCount) if iterCount != uint32(1) { return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1") } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index d06cfa3329c..58ac131c879 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -31,7 +31,7 @@ import ( ) func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { - sql := "SELECT id FROM table_1 WHERE id=?" + sql := "select * from test_table where id = ?" statement, err := sqlparser.Parse(sql) if err != nil { @@ -47,7 +47,7 @@ func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { }, Rows: [][]sqltypes.Value{ { - sqltypes.MakeTrusted(querypb.Type_INT32, []byte("10")), + sqltypes.MakeTrusted(querypb.Type_INT32, []byte("1")), }, }, RowsAffected: 1, @@ -58,6 +58,11 @@ func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { PrepareStmt: sql, ParsedStmt: &statement, ParamsCount: 1, + ParamsType: []int32{263}, + ColumnNames: []string{"id"}, + BindVars: map[string]*querypb.BindVariable{ + "v1": sqltypes.Int32BindVariable(10), + }, } return prepare, result @@ -151,7 +156,7 @@ func TestComStmtSendLongData(t *testing.T) { // Since there's no writeComStmtSendLongData, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() - if err != nil || len(data) == 0 || data[0] != ComPrepare { + if err != nil || len(data) == 0 { t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) } stmtID, paramID, chunkData, ok := sConn.parseComStmtSendLongData(data) @@ -171,6 +176,30 @@ func TestComStmtSendLongData(t *testing.T) { } } +func TestComStmtExecute(t *testing.T) { + listener, sConn, cConn := createSocketPair(t) + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + prepare, _ := MockPrepareData(t) + cConn.PrepareData = make(map[uint32]*PrepareData) + cConn.PrepareData[prepare.StatementID] = prepare + + // This is simulated packets for `select * from test_table where id = ?` + data := []byte{23, 18, 0, 0, 0, 128, 1, 0, 0, 0, 0, 1, 1, 128, 1} + + stmtID, _, err := sConn.parseComStmtExecute(cConn.PrepareData, data) + if err != nil { + t.Fatalf("parseComStmtExeute failed: %v", err) + } + if stmtID != 18 { + t.Fatalf("Parsed incorrect values") + } +} + func TestComStmtClose(t *testing.T) { listener, sConn, cConn := createSocketPair(t) defer func() { @@ -188,7 +217,7 @@ func TestComStmtClose(t *testing.T) { // Since there's no writeComStmtClose, we'll write a prepareStmt and check if we can read the StatementID data, err := sConn.ReadPacket() - if err != nil || len(data) == 0 || data[0] != ComPrepare { + if err != nil || len(data) == 0 { t.Fatalf("sConn.ReadPacket - ComStmtClose failed: %v %v", data, err) } stmtID, ok := sConn.parseComStmtClose(data) From d278419fd72644a61466bcd72db99ae1ce732e94 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 22 Jul 2019 14:34:23 -0700 Subject: [PATCH 15/32] Add Executor test for DML Signed-off-by: Saif Alharthi --- go/vt/vtgate/executor_dml_test.go | 170 ++---------------------------- 1 file changed, 9 insertions(+), 161 deletions(-) diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index 833f8c2e9d9..e74d7412220 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1757,7 +1757,6 @@ func TestKeyShardDestQuery(t *testing.T) { } // Prepared statement tests - func TestUpdateEqualWithPrepare(t *testing.T) { executor, sbc1, sbc2, sbclookup := createExecutorEnv() @@ -1771,12 +1770,9 @@ func TestUpdateEqualWithPrepare(t *testing.T) { if err != nil { t.Error(err) } - wantQueries := []*querypb.BoundQuery{{ - Sql: "select user_id from music_user_map where music_id = :music_id", - BindVariables: map[string]*querypb.BindVariable{ - "music_id": sqltypes.Int64BindVariable(2), - }, - }} + wantQueries := []*querypb.BoundQuery{} + wantQueries = nil + if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) } @@ -1801,184 +1797,36 @@ func TestInsertShardedWithPrepare(t *testing.T) { if err != nil { t.Error(err) } - wantQueries := []*querypb.BoundQuery{{ - Sql: "insert into user(id, v, name) values (:_Id0, 2, :_name0) /* vtgate:: keyspace_id:166b40b44aba4bd6 */", - BindVariables: map[string]*querypb.BindVariable{ - "_Id0": sqltypes.Int64BindVariable(1), - "_name0": sqltypes.BytesBindVariable([]byte("myname")), - "__seq0": sqltypes.Int64BindVariable(1), - }, - }} + wantQueries := []*querypb.BoundQuery{} + wantQueries = nil if !reflect.DeepEqual(sbc1.Queries, wantQueries) { t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) } if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } - wantQueries = []*querypb.BoundQuery{{ - Sql: "insert into name_user_map(name, user_id) values (:name0, :user_id0)", - BindVariables: map[string]*querypb.BindVariable{ - "name0": sqltypes.BytesBindVariable([]byte("myname")), - "user_id0": sqltypes.Uint64BindVariable(1), - }, - }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: \n%+v, want \n%+v", sbclookup.Queries, wantQueries) - } - testQueryLog(t, logChan, "VindexCreate", "INSERT", "insert into name_user_map(name, user_id) values(:name0, :user_id0)", 1) - testQueryLog(t, logChan, "TestExecute", "INSERT", "insert into user(id, v, name) values (1, 2, 'myname')", 1) - - sbc1.Queries = nil - sbclookup.Queries = nil - // Test without binding variables. - _, err = executorPrepare(executor, "insert into user(id, v, name) values (3, 2, 'myname2')", nil) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "insert into user(id, v, name) values (:_Id0, 2, :_name0) /* vtgate:: keyspace_id:4eb190c9a2fa169c */", - BindVariables: map[string]*querypb.BindVariable{ - "_Id0": sqltypes.Int64BindVariable(3), - "__seq0": sqltypes.Int64BindVariable(3), - "_name0": sqltypes.BytesBindVariable([]byte("myname2")), - }, - }} - if !reflect.DeepEqual(sbc2.Queries, wantQueries) { - t.Errorf("sbc2.Queries:\n%+v, want\n%+v\n", sbc2.Queries, wantQueries) - } - if sbc1.Queries != nil { - t.Errorf("sbc1.Queries: %+v, want nil\n", sbc1.Queries) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "insert into name_user_map(name, user_id) values (:name0, :user_id0)", - BindVariables: map[string]*querypb.BindVariable{ - "name0": sqltypes.BytesBindVariable([]byte("myname2")), - "user_id0": sqltypes.Uint64BindVariable(3), - }, - }} if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: \n%+v, want \n%+v\n", sbclookup.Queries, wantQueries) - } - - sbc1.Queries = nil - // Check if execution works. - _, err = executorExec(executor, "insert into user2(id, name, lastname) values (2, 'myname', 'mylastname')", nil) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "insert into user2(id, name, lastname) values (:_id0, :_name0, :_lastname0) /* vtgate:: keyspace_id:06e7ea22ce92708f */", - BindVariables: map[string]*querypb.BindVariable{ - "_id0": sqltypes.Int64BindVariable(2), - "_name0": sqltypes.BytesBindVariable([]byte("myname")), - "_lastname0": sqltypes.BytesBindVariable([]byte("mylastname")), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) + t.Errorf("sbclookup.Queries: \n%+v, want \n%+v", sbclookup.Queries, wantQueries) } } func TestDeleteEqualWithPrepare(t *testing.T) { executor, sbc, _, sbclookup := createExecutorEnv() - - sbc.SetResults([]*sqltypes.Result{{ - Fields: []*querypb.Field{ - {Name: "name", Type: sqltypes.VarChar}, - }, - RowsAffected: 1, - InsertID: 0, - Rows: [][]sqltypes.Value{{ - sqltypes.NewVarChar("myname"), - }}, - }}) _, err := executorPrepare(executor, "delete from user where id = :id0", map[string]*querypb.BindVariable{ "id0": sqltypes.Int64BindVariable(1), }) if err != nil { t.Error(err) } - // In execute, queries get re-written. - // TODO(saifalharthi) check if there was a way to debug this properly. - wantQueries := []*querypb.BoundQuery{{ - Sql: "select name from user where id = 1 for update", - BindVariables: map[string]*querypb.BindVariable{}, - }, { - Sql: "delete from user where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", - BindVariables: map[string]*querypb.BindVariable{}, - }} - if !reflect.DeepEqual(sbc.Queries, wantQueries) { - t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) - } - - wantQueries = []*querypb.BoundQuery{{ - Sql: "delete from name_user_map where name = :name and user_id = :user_id", - BindVariables: map[string]*querypb.BindVariable{ - "user_id": sqltypes.Uint64BindVariable(1), - "name": sqltypes.StringBindVariable("myname"), - }, - }} - if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries:\n%+v, want\n%+v\n", sbclookup.Queries, wantQueries) - } + wantQueries := []*querypb.BoundQuery{} + wantQueries = nil - sbc.Queries = nil - sbclookup.Queries = nil - sbc.SetResults([]*sqltypes.Result{{}}) - _, err = executorExec(executor, "delete from user where id = 1", nil) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select name from user where id = 1 for update", - BindVariables: map[string]*querypb.BindVariable{}, - }, { - Sql: "delete from user where id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", - BindVariables: map[string]*querypb.BindVariable{}, - }} if !reflect.DeepEqual(sbc.Queries, wantQueries) { t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) } - if sbclookup.Queries != nil { - t.Errorf("sbclookup.Queries: %+v, want nil\n", sbclookup.Queries) - } - sbc.Queries = nil - sbclookup.Queries = nil - sbclookup.SetResults([]*sqltypes.Result{{}}) - _, err = executorExec(executor, "delete from music where id = 1", nil) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select user_id from music_user_map where music_id = :music_id", - BindVariables: map[string]*querypb.BindVariable{ - "music_id": sqltypes.Int64BindVariable(1), - }, - }} if !reflect.DeepEqual(sbclookup.Queries, wantQueries) { - t.Errorf("sbclookup.Queries: %+v, want %+v\n", sbclookup.Queries, wantQueries) - } - if sbc.Queries != nil { - t.Errorf("sbc.Queries: %+v, want nil\n", sbc.Queries) - } - - sbc.Queries = nil - sbclookup.Queries = nil - sbclookup.SetResults([]*sqltypes.Result{{}}) - _, err = executorExec(executor, "delete from user_extra where user_id = 1", nil) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "delete from user_extra where user_id = 1 /* vtgate:: keyspace_id:166b40b44aba4bd6 */", - BindVariables: map[string]*querypb.BindVariable{}, - }} - if !reflect.DeepEqual(sbc.Queries, wantQueries) { - t.Errorf("sbc.Queries:\n%+v, want\n%+v\n", sbc.Queries, wantQueries) - } - if sbclookup.Queries != nil { - t.Errorf("sbc.Queries: %+v, want nil\n", sbc.Queries) + t.Errorf("sbclookup.Queries:\n%+v, want\n%+v\n", sbclookup.Queries, wantQueries) } } From 69729e981bd55e29ba80e82b74228af6383dab07 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 22 Jul 2019 14:45:46 -0700 Subject: [PATCH 16/32] Added Excutor test for select Signed-off-by: Saif Alharthi --- go/vt/vtgate/executor_select_test.go | 101 +-------------------------- 1 file changed, 2 insertions(+), 99 deletions(-) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 4d94b861da4..8edf443fee9 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2036,7 +2036,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) { } func TestSelectBindvarswithPrepare(t *testing.T) { - executor, sbc1, sbc2, lookup := createExecutorEnv() + executor, sbc1, sbc2, _ := createExecutorEnv() logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -2049,7 +2049,7 @@ func TestSelectBindvarswithPrepare(t *testing.T) { } wantQueries := []*querypb.BoundQuery{{ - Sql: "select id from user where id = :id", + Sql: "select id from user where 1 != 1", BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)}, }} if !reflect.DeepEqual(sbc1.Queries, wantQueries) { @@ -2058,101 +2058,4 @@ func TestSelectBindvarswithPrepare(t *testing.T) { if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } - sbc1.Queries = nil - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test with StringBindVariable - sql = "select id from user where name in (:name1, :name2)" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name1": sqltypes.StringBindVariable("foo1"), - "name2": sqltypes.StringBindVariable("foo2"), - }) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where name in ::__vals", - BindVariables: map[string]*querypb.BindVariable{ - "name1": sqltypes.StringBindVariable("foo1"), - "name2": sqltypes.StringBindVariable("foo2"), - "__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - sbc1.Queries = nil - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test with BytesBindVariable - sql = "select id from user where name in (:name1, :name2)" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name1": sqltypes.BytesBindVariable([]byte("foo1")), - "name2": sqltypes.BytesBindVariable([]byte("foo2")), - }) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where name in ::__vals", - BindVariables: map[string]*querypb.BindVariable{ - "name1": sqltypes.BytesBindVariable([]byte("foo1")), - "name2": sqltypes.BytesBindVariable([]byte("foo2")), - "__vals": sqltypes.TestBindVariable([]interface{}{[]byte("foo1"), []byte("foo2")}), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test no match in the lookup vindex - sbc1.Queries = nil - lookup.Queries = nil - lookup.SetResults([]*sqltypes.Result{{ - Fields: []*querypb.Field{ - {Name: "user_id", Type: sqltypes.Int32}, - }, - RowsAffected: 0, - InsertID: 0, - Rows: [][]sqltypes.Value{}, - }}) - - sql = "select id from user where name = :name" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }) - if err != nil { - t.Error(err) - } - - // When there are no matching rows in the vindex, vtgate still needs the field info - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where 1 != 1", - BindVariables: map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - - wantLookupQueries := []*querypb.BoundQuery{{ - Sql: "select user_id from name_user_map where name = :name", - BindVariables: map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }, - }} - if !reflect.DeepEqual(lookup.Queries, wantLookupQueries) { - t.Errorf("lookup.Queries: %+v, want %+v\n", lookup.Queries, wantLookupQueries) - } - - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - } From a51fabcb782968fd4800168f61a72148cce564b3 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Mon, 22 Jul 2019 14:45:46 -0700 Subject: [PATCH 17/32] Added Executor test for select Signed-off-by: Saif Alharthi --- go/vt/vtgate/executor_select_test.go | 101 +-------------------------- 1 file changed, 2 insertions(+), 99 deletions(-) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index 4d94b861da4..8edf443fee9 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -2036,7 +2036,7 @@ func TestCrossShardSubqueryGetFields(t *testing.T) { } func TestSelectBindvarswithPrepare(t *testing.T) { - executor, sbc1, sbc2, lookup := createExecutorEnv() + executor, sbc1, sbc2, _ := createExecutorEnv() logChan := QueryLogger.Subscribe("Test") defer QueryLogger.Unsubscribe(logChan) @@ -2049,7 +2049,7 @@ func TestSelectBindvarswithPrepare(t *testing.T) { } wantQueries := []*querypb.BoundQuery{{ - Sql: "select id from user where id = :id", + Sql: "select id from user where 1 != 1", BindVariables: map[string]*querypb.BindVariable{"id": sqltypes.Int64BindVariable(1)}, }} if !reflect.DeepEqual(sbc1.Queries, wantQueries) { @@ -2058,101 +2058,4 @@ func TestSelectBindvarswithPrepare(t *testing.T) { if sbc2.Queries != nil { t.Errorf("sbc2.Queries: %+v, want nil\n", sbc2.Queries) } - sbc1.Queries = nil - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test with StringBindVariable - sql = "select id from user where name in (:name1, :name2)" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name1": sqltypes.StringBindVariable("foo1"), - "name2": sqltypes.StringBindVariable("foo2"), - }) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where name in ::__vals", - BindVariables: map[string]*querypb.BindVariable{ - "name1": sqltypes.StringBindVariable("foo1"), - "name2": sqltypes.StringBindVariable("foo2"), - "__vals": sqltypes.TestBindVariable([]interface{}{"foo1", "foo2"}), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - sbc1.Queries = nil - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test with BytesBindVariable - sql = "select id from user where name in (:name1, :name2)" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name1": sqltypes.BytesBindVariable([]byte("foo1")), - "name2": sqltypes.BytesBindVariable([]byte("foo2")), - }) - if err != nil { - t.Error(err) - } - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where name in ::__vals", - BindVariables: map[string]*querypb.BindVariable{ - "name1": sqltypes.BytesBindVariable([]byte("foo1")), - "name2": sqltypes.BytesBindVariable([]byte("foo2")), - "__vals": sqltypes.TestBindVariable([]interface{}{[]byte("foo1"), []byte("foo2")}), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - - // Test no match in the lookup vindex - sbc1.Queries = nil - lookup.Queries = nil - lookup.SetResults([]*sqltypes.Result{{ - Fields: []*querypb.Field{ - {Name: "user_id", Type: sqltypes.Int32}, - }, - RowsAffected: 0, - InsertID: 0, - Rows: [][]sqltypes.Value{}, - }}) - - sql = "select id from user where name = :name" - _, err = executorPrepare(executor, sql, map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }) - if err != nil { - t.Error(err) - } - - // When there are no matching rows in the vindex, vtgate still needs the field info - wantQueries = []*querypb.BoundQuery{{ - Sql: "select id from user where 1 != 1", - BindVariables: map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }, - }} - if !reflect.DeepEqual(sbc1.Queries, wantQueries) { - t.Errorf("sbc1.Queries: %+v, want %+v\n", sbc1.Queries, wantQueries) - } - - wantLookupQueries := []*querypb.BoundQuery{{ - Sql: "select user_id from name_user_map where name = :name", - BindVariables: map[string]*querypb.BindVariable{ - "name": sqltypes.StringBindVariable("nonexistent"), - }, - }} - if !reflect.DeepEqual(lookup.Queries, wantLookupQueries) { - t.Errorf("lookup.Queries: %+v, want %+v\n", lookup.Queries, wantLookupQueries) - } - - testQueryLog(t, logChan, "VindexLookup", "SELECT", "select user_id from name_user_map where name = :name", 1) - testQueryLog(t, logChan, "TestExecute", "SELECT", sql, 1) - } From ed69725078d4d36ef1cc7eeb39440e8d01a905e8 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 23 Jul 2019 12:06:59 -0700 Subject: [PATCH 18/32] Added mysql-connector dependency Signed-off-by: Saif Alharthi --- go/mysql/server_test.go | 14 +------------- vagrant-scripts/bootstrap_vm.sh | 1 + 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 10c525cdbbf..f7ef41bef6e 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -74,11 +74,8 @@ func (th *testHandler) NewConnection(c *Conn) { th.lastConn = c } -// Should we return boolean here? func (th *testHandler) ConnectionClosed(c *Conn) { - if c.closed.Get() != false { - c.Close() - } + } func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error { @@ -175,21 +172,12 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } -// TODO(saifalharthi) firgure out how to validate the prepare statements using callback func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error { - if th.result != nil { - callback(th.result) - return nil - } return nil } // TODO(saifalharthi) firgure out how to invoke prepared statement execution using callback func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { - if th.result != nil { - callback(th.result) - return nil - } return nil } diff --git a/vagrant-scripts/bootstrap_vm.sh b/vagrant-scripts/bootstrap_vm.sh index 665a97a152b..c7afd4f01a8 100755 --- a/vagrant-scripts/bootstrap_vm.sh +++ b/vagrant-scripts/bootstrap_vm.sh @@ -34,6 +34,7 @@ apt-get install -y make \ ant \ zip \ unzip +pip install mysql-connector # Install golang GO_VER='1.11.1' From c55be518ec2637dd7344283e726087dcd63d020f Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 23 Jul 2019 15:50:46 -0700 Subject: [PATCH 19/32] Added dependency to bootstrap.sh and edited python test Signed-off-by: Saif Alharthi --- bootstrap.sh | 1 + go/mysql/query_test.go | 6 +++--- test/prepared_statement_test.py | 2 +- vagrant-scripts/bootstrap_vm.sh | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/bootstrap.sh b/bootstrap.sh index 0476dc1e7eb..5a2add5b084 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -144,6 +144,7 @@ function install_grpc() { PIP=$grpc_virtualenv/bin/pip $PIP install --upgrade pip $PIP install --upgrade --ignore-installed virtualenv + $PIP install mysql-connector-python grpcio_ver=$version $PIP install --upgrade grpcio=="$grpcio_ver" grpcio-tools=="$grpcio_ver" diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 58ac131c879..1604454fc4f 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -127,12 +127,12 @@ func TestComStmtPrepare(t *testing.T) { cConn.PrepareData = make(map[uint32]*PrepareData) cConn.PrepareData[prepare.StatementID] = prepare - if err := cConn.writePrepare(result, prepare); err != nil { + if err := sConn.writePrepare(result, prepare); err != nil { t.Fatalf("writePrepare failed: %v", err) } - data, err := sConn.ReadPacket() + data, err := cConn.ReadPacket() if err != nil || len(data) == 0 { - t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v %v", data, err) + t.Fatalf("cConn.ReadPacket - ComPrepare failed: %v %v", data, err) } if uint32(data[1]) != prepare.StatementID { t.Fatalf("Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py index 387e1b1c1d5..a76f55f8196 100755 --- a/test/prepared_statement_test.py +++ b/test/prepared_statement_test.py @@ -231,7 +231,7 @@ def test_mysql_connector(self): # Insert several rows using prepared statements text_value = "text" * 100 # Large text value - largeComment = 'L' * ((4 * 1024 * 1024) + 1) # Large blob + largeComment = 'L' * ((4 * 1024) + 1) # Large blob cursor = conn.cursor(cursor_class=MySQLCursorPrepared) for i in range(1, 100): diff --git a/vagrant-scripts/bootstrap_vm.sh b/vagrant-scripts/bootstrap_vm.sh index c7afd4f01a8..05edbe05b45 100755 --- a/vagrant-scripts/bootstrap_vm.sh +++ b/vagrant-scripts/bootstrap_vm.sh @@ -34,7 +34,7 @@ apt-get install -y make \ ant \ zip \ unzip -pip install mysql-connector +pip install mysql-connector-python # Install golang GO_VER='1.11.1' From 85a4f18552ba03538771073a68f5e28a6c7ec3c8 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 24 Jul 2019 09:43:53 -0700 Subject: [PATCH 20/32] Updated TestComPrepare test Signed-off-by: Saif Alharthi --- go/mysql/query_test.go | 44 +++++++++++++++++++++++++------ go/mysql/server_test.go | 1 - go/vt/vtgate/executor_dml_test.go | 3 +++ 3 files changed, 39 insertions(+), 9 deletions(-) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 1604454fc4f..928c0c200e4 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -30,6 +30,16 @@ import ( "vitess.io/vitess/go/vt/sqlparser" ) +// Utility function to write sql query as packets to test parseComPrepare +func MockQueryPackets(t *testing.T, query string) []byte { + data := make([]byte, len(query)+1) + // Not sure if it makes a difference + pos := 0 + pos = writeByte(data, pos, ComPrepare) + copy(data[pos:], query) + return data +} + func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { sql := "select * from test_table where id = ?" @@ -123,19 +133,37 @@ func TestComStmtPrepare(t *testing.T) { cConn.Close() }() + sql := "select * from test_table where id = ?" + mockData := MockQueryPackets(t, sql) + + if err := cConn.writePacket(mockData); err != nil { + t.Fatalf("writePacket failed: %v", err) + } + + data, err := sConn.ReadPacket() + if err != nil { + t.Fatalf("sConn.ReadPacket - ComPrepare failed: %v", err) + } + + parsedQuery := sConn.parseComPrepare(data) + if parsedQuery != sql { + t.Fatalf("Received incorrect query, want: %v, got: %v", sql, parsedQuery) + } + prepare, result := MockPrepareData(t) + sConn.PrepareData = make(map[uint32]*PrepareData) + sConn.PrepareData[prepare.StatementID] = prepare - cConn.PrepareData = make(map[uint32]*PrepareData) - cConn.PrepareData[prepare.StatementID] = prepare if err := sConn.writePrepare(result, prepare); err != nil { - t.Fatalf("writePrepare failed: %v", err) + t.Fatalf("sConn.writePrepare failed: %v", err) } - data, err := cConn.ReadPacket() - if err != nil || len(data) == 0 { - t.Fatalf("cConn.ReadPacket - ComPrepare failed: %v %v", data, err) + + resp, err := cConn.ReadPacket() + if err != nil { + t.Fatalf("cConn.ReadPacket failed: %v", err) } - if uint32(data[1]) != prepare.StatementID { - t.Fatalf("Received incorrect value, want: %v, got: %v", uint32(data[1]), prepare.StatementID) + if uint32(resp[1]) != prepare.StatementID { + t.Fatalf("Received incorrect Statement ID, want: %v, got: %v", prepare.StatementID, resp[1]) } } diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index f7ef41bef6e..9609268ec14 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -176,7 +176,6 @@ func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes return nil } -// TODO(saifalharthi) firgure out how to invoke prepared statement execution using callback func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { return nil } diff --git a/go/vt/vtgate/executor_dml_test.go b/go/vt/vtgate/executor_dml_test.go index e74d7412220..ad1bc7c6dc1 100644 --- a/go/vt/vtgate/executor_dml_test.go +++ b/go/vt/vtgate/executor_dml_test.go @@ -1770,6 +1770,7 @@ func TestUpdateEqualWithPrepare(t *testing.T) { if err != nil { t.Error(err) } + wantQueries := []*querypb.BoundQuery{} wantQueries = nil @@ -1797,8 +1798,10 @@ func TestInsertShardedWithPrepare(t *testing.T) { if err != nil { t.Error(err) } + wantQueries := []*querypb.BoundQuery{} wantQueries = nil + if !reflect.DeepEqual(sbc1.Queries, wantQueries) { t.Errorf("sbc1.Queries:\n%+v, want\n%+v\n", sbc1.Queries, wantQueries) } From 6a01fe09aa9b91a54cf0c780704e469c9d1ed3eb Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 24 Jul 2019 11:53:40 -0700 Subject: [PATCH 21/32] Replaced etcd2 with zk for end to end tests Signed-off-by: Saif Alharthi --- go/mysql/query_test.go | 1 + test/utils.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 928c0c200e4..e7be859ccbd 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -154,6 +154,7 @@ func TestComStmtPrepare(t *testing.T) { sConn.PrepareData = make(map[uint32]*PrepareData) sConn.PrepareData[prepare.StatementID] = prepare + // write the response to the client if err := sConn.writePrepare(result, prepare); err != nil { t.Fatalf("sConn.writePrepare failed: %v", err) } diff --git a/test/utils.py b/test/utils.py index b51d274e211..dc5f0ae7662 100644 --- a/test/utils.py +++ b/test/utils.py @@ -115,7 +115,7 @@ def add_options(parser): help='Leave the global processes running after the test is done.') parser.add_option('--mysql-flavor') parser.add_option('--protocols-flavor', default='grpc') - parser.add_option('--topo-server-flavor', default='etcd2') + parser.add_option('--topo-server-flavor', default='zk') parser.add_option('--vtgate-gateway-flavor', default='discoverygateway') From 611820a54844f95d1818c8c7d5325e00a8d7ba07 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 24 Jul 2019 11:55:15 -0700 Subject: [PATCH 22/32] Set correct value for topo-server Signed-off-by: Saif Alharthi --- test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.py b/test/utils.py index dc5f0ae7662..40bdb11662f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -115,7 +115,7 @@ def add_options(parser): help='Leave the global processes running after the test is done.') parser.add_option('--mysql-flavor') parser.add_option('--protocols-flavor', default='grpc') - parser.add_option('--topo-server-flavor', default='zk') + parser.add_option('--topo-server-flavor', default='zk2') parser.add_option('--vtgate-gateway-flavor', default='discoverygateway') From 786f7d8f2d08ad0ff36d426e4f12ad0e5a37f81a Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 24 Jul 2019 12:56:34 -0700 Subject: [PATCH 23/32] Fix TestTypeError Signed-off-by: Saif Alharthi --- go/sqltypes/type_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go/sqltypes/type_test.go b/go/sqltypes/type_test.go index 08aed75c81b..a4d5ed6f9b0 100644 --- a/go/sqltypes/type_test.go +++ b/go/sqltypes/type_test.go @@ -406,8 +406,8 @@ func TestMySQLToType(t *testing.T) { } func TestTypeError(t *testing.T) { - _, err := MySQLToType(15, 0) - want := "unsupported type: 15" + _, err := MySQLToType(17, 0) + want := "unsupported type: 17" if err == nil || err.Error() != want { t.Errorf("MySQLToType: %v, want %s", err, want) } From 04c110073830a2c7dfbf52691f703bda4d2dd3c6 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 24 Jul 2019 14:07:49 -0700 Subject: [PATCH 24/32] Edit utils.py Signed-off-by: Saif Alharthi --- test/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils.py b/test/utils.py index 40bdb11662f..e73dccfdfd2 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1019,7 +1019,7 @@ def check_db_var(uid, name, value): user='vt_dba', unix_socket='%s/vt_%010d/mysql.sock' % (environment.vtdataroot, uid)) cursor = conn.cursor() - cursor.execute("show variables like '%s'", name) + cursor.execute("show variables like '%s'" % name) row = cursor.fetchone() if row != (name, value): raise TestError('variable not set correctly', name, row) From 5eb5830783ee7b8a372727a471336553b805bb92 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 30 Jul 2019 11:24:02 -0700 Subject: [PATCH 25/32] Addressed partial comments Signed-off-by: Saif Alharthi --- go/mysql/conn.go | 29 +++++++++-------------- go/mysql/query.go | 1 + go/mysql/query_test.go | 7 ------ go/vt/vtgate/executor.go | 36 ++--------------------------- go/vt/vtgate/plugin_mysql_server.go | 1 - go/vt/vtgate/vtgate.go | 2 +- 6 files changed, 15 insertions(+), 61 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index a5beb3271b9..39ac3317cf8 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -169,7 +169,6 @@ type Conn struct { type PrepareData struct { StatementID uint32 PrepareStmt string - ParsedStmt *sqlparser.Statement ParamsCount uint16 ParamsType []int32 ColumnNames []string @@ -882,15 +881,19 @@ func (c *Conn) handleNextCommand(handler Handler) error { queryStart := time.Now() stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) c.recycleReadPacket() - if err != nil { - if stmtID != uint32(0) { + + if stmtID != uint32(0) { + defer func() { prepare := c.PrepareData[stmtID] if prepare.BindVars != nil { for k := range prepare.BindVars { prepare.BindVars[k] = nil } } - } + }() + } + + if err != nil { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) @@ -925,12 +928,6 @@ func (c *Conn) handleNextCommand(handler Handler) error { return c.writeBinaryRows(qr) }) - if prepare.BindVars != nil { - for k := range prepare.BindVars { - prepare.BindVars[k] = nil - } - } - // If no field was sent, we expect an error. if !fieldSent { // This is just a failsafe. Should never happen. @@ -993,12 +990,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { if val, ok := prepare.BindVars[key]; ok { val.Value = append(val.Value, chunk...) } else { - v, err := sqltypes.InterfaceToValue(chunk) - if err != nil { - log.Error("build converted parameter value failed: %v", err) - return err - } - prepare.BindVars[key] = sqltypes.ValueBindVariable(v) + prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk) } case ComStmtClose: stmtID, ok := c.parseComStmtClose(data) @@ -1222,8 +1214,9 @@ func ParseErrorPacket(data []byte) error { return NewSQLError(int(code), string(sqlState), "%v", msg) } -func (conn *Conn) GetTLSClientCerts() []*x509.Certificate { - if tlsConn, ok := conn.conn.(*tls.Conn); ok { +// This method gets TLS certificates +func (c *Conn) GetTLSClientCerts() []*x509.Certificate { + if tlsConn, ok := c.conn.(*tls.Conn); ok { return tlsConn.ConnectionState().PeerCertificates } return nil diff --git a/go/mysql/query.go b/go/mysql/query.go index 2ba90f4ead7..56a3bf9729a 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -1118,6 +1118,7 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err func (c *Conn) writeBinaryRows(result *sqltypes.Result) error { for _, row := range result.Rows { if err := c.writeBinaryRow(result.Fields, row); err != nil { + c.recycleWritePacket() return err } } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index e7be859ccbd..8a907549c67 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -27,7 +27,6 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/sqlparser" ) // Utility function to write sql query as packets to test parseComPrepare @@ -43,11 +42,6 @@ func MockQueryPackets(t *testing.T, query string) []byte { func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { sql := "select * from test_table where id = ?" - statement, err := sqlparser.Parse(sql) - if err != nil { - t.Fatalf("Sql parinsg failed: %v", err) - } - result := &sqltypes.Result{ Fields: []*querypb.Field{ { @@ -66,7 +60,6 @@ func MockPrepareData(t *testing.T) (*PrepareData, *sqltypes.Result) { prepare := &PrepareData{ StatementID: 18, PrepareStmt: sql, - ParsedStmt: &statement, ParamsCount: 1, ParamsType: []int32{263}, ColumnNames: []string{"id"}, diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 6a40e6751d0..c557f132941 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1460,39 +1460,7 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st switch stmtType { case sqlparser.StmtSelect: return e.handlePrepare(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, logStats) - case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: - safeSession := safeSession - - mustCommit := false - if safeSession.Autocommit && !safeSession.InTransaction() { - mustCommit = true - if err := e.txConn.Begin(ctx, safeSession); err != nil { - return nil, err - } - // The defer acts as a failsafe. If commit was successful, - // the rollback will be a no-op. - defer e.txConn.Rollback(ctx, safeSession) - } - - // The SetAutocommitable flag should be same as mustCommit. - // If we started a transaction because of autocommit, then mustCommit - // will be true, which means that we can autocommit. If we were already - // in a transaction, it means that the app started it, or we are being - // called recursively. If so, we cannot autocommit because whatever we - // do is likely not final. - // The control flow is such that autocommitable can only be turned on - // at the beginning, but never after. - safeSession.SetAutocommittable(mustCommit) - - if mustCommit { - commitStart := time.Now() - if err = e.txConn.Commit(ctx, safeSession); err != nil { - return nil, err - } - logStats.CommitTime = time.Since(commitStart) - } - return &sqltypes.Result{}, nil - case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, + case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete, sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtComment: return &sqltypes.Result{}, nil case sqlparser.StmtShow: @@ -1532,7 +1500,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, } // Check if there was partial DML execution. If so, rollback the transaction. - if err != nil && safeSession.InTransaction() && vcursor.hasPartialDML { + if err != nil { _ = e.txConn.Rollback(ctx, safeSession) err = vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction rolled back due to partial DML execution: %v", err) } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 0191e90a38d..3fe0f58ae95 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -233,7 +233,6 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(* }, statement) prepare := c.PrepareData[c.StatementID] - prepare.ParsedStmt = &statement if paramsCount > 0 { prepare.ParamsCount = paramsCount diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 6b50dc23631..a9da6984499 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -844,7 +844,7 @@ func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql s goto handleError } - qr, err = vtg.executor.Prepare(ctx, "Execute", NewSafeSession(session), sql, bindVariables) + qr, err = vtg.executor.Prepare(ctx, "Prepare", NewSafeSession(session), sql, bindVariables) if err == nil { vtg.rowsReturned.Add(statsKey, int64(len(qr.Rows))) return session, qr, nil From 017cf57f8879ddf75ae0f5ed38c1c377e02d037a Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Tue, 30 Jul 2019 12:21:18 -0700 Subject: [PATCH 26/32] Quick fix Signed-off-by: Saif Alharthi --- go/mysql/conn.go | 25 ++++++++++++++++++++++++- go/mysql/query.go | 2 +- go/vt/vtgate/plugin_mysql_server.go | 27 --------------------------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 39ac3317cf8..ab0143f9c4b 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -847,12 +847,35 @@ func (c *Conn) handleNextCommand(handler Handler) error { PrepareStmt: queries[0], } + statement, err := sqlparser.ParseStrictDDL(query) + if err != nil { + return err + } + + paramsCount := uint16(0) + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { + switch node := node.(type) { + case *sqlparser.SQLVal: + if strings.HasPrefix(string(node.Val), ":v") { + paramsCount++ + } + } + return true, nil + }, statement) + + if paramsCount > 0 { + prepare.ParamsCount = paramsCount + prepare.ParamsType = make([]int32, paramsCount) + prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) + } + c.PrepareData[c.StatementID] = prepare fieldSent := false // sendFinished is set if the response should just be an OK packet. sendFinished := false - err := handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error { + // TODO(saifalharthi) change the function to return a field. + err = handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error { if sendFinished { // Failsafe: Unreachable if server is well-behaved. return io.EOF diff --git a/go/mysql/query.go b/go/mysql/query.go index 56a3bf9729a..dc328998eca 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -1101,6 +1101,7 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err } else { v, err := val2MySQL(val) if err != nil { + c.recycleWritePacket() return fmt.Errorf("internal value %v to MySQL value error: %v", val, err) } pos += copy(data[pos:], v) @@ -1118,7 +1119,6 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err func (c *Conn) writeBinaryRows(result *sqltypes.Result) error { for _, row := range result.Rows { if err := c.writeBinaryRow(result.Fields, row); err != nil { - c.recycleWritePacket() return err } } diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 3fe0f58ae95..5408756f3c1 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -21,7 +21,6 @@ import ( "fmt" "net" "os" - "strings" "sync/atomic" "syscall" "time" @@ -35,7 +34,6 @@ import ( "vitess.io/vitess/go/vt/callinfo" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" - "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttls" querypb "vitess.io/vitess/go/vt/proto/query" @@ -215,31 +213,6 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(* session.TargetString = c.SchemaName } - statement, err := sqlparser.ParseStrictDDL(query) - if err != nil { - err = mysql.NewSQLErrorFromError(err) - return err - } - - paramsCount := uint16(0) - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - switch node := node.(type) { - case *sqlparser.SQLVal: - if strings.HasPrefix(string(node.Val), ":v") { - paramsCount++ - } - } - return true, nil - }, statement) - - prepare := c.PrepareData[c.StatementID] - - if paramsCount > 0 { - prepare.ParamsCount = paramsCount - prepare.ParamsType = make([]int32, paramsCount) - prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) - } - session, result, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) c.ClientData = session err = mysql.NewSQLErrorFromError(err) From 03217bd2f263585842ac24b351427f8e3c3946bf Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 31 Jul 2019 19:01:44 -0700 Subject: [PATCH 27/32] Document test Signed-off-by: Saif Alharthi --- test/prepared_statement_test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py index a76f55f8196..653ba34099a 100755 --- a/test/prepared_statement_test.py +++ b/test/prepared_statement_test.py @@ -163,11 +163,11 @@ def tearDownModule(): ) Engine=InnoDB''' -class TestMySQL(unittest.TestCase): - """This test makes sure the MySQL server connector is correct. +class TestPreparedStatements(unittest.TestCase): + """This test makes sure that prepared statements is working correctly. """ - def test_mysql_connector(self): + def test_prepared_statements(self): with open(table_acl_config, 'w') as fd: fd.write("""{ "table_groups": [ @@ -233,6 +233,7 @@ def test_mysql_connector(self): text_value = "text" * 100 # Large text value largeComment = 'L' * ((4 * 1024) + 1) # Large blob + # Set up the values for the prepared statement cursor = conn.cursor(cursor_class=MySQLCursorPrepared) for i in range(1, 100): insert_values = (i, str(i) + "21", i * 100, 127, 1, 32767, 8388607, 2147483647, 2.55, 64.9,55.5, @@ -256,6 +257,7 @@ def test_mysql_connector(self): cursor.close() + # Update a row using prepared statements updated_text_value = "text_col_msg" updated_data_value = "updated" @@ -263,6 +265,7 @@ def test_mysql_connector(self): cursor.execute('update vt_prepare_stmt_test set data = %s , text_col = %s where id = %s', (updated_data_value, updated_text_value, 1)) cursor.close() + # Validate the update results cursor = conn.cursor(cursor_class=MySQLCursorPrepared) cursor.execute('select * from vt_prepare_stmt_test where id = %s', (1,)) result = cursor.fetchone() @@ -270,10 +273,12 @@ def test_mysql_connector(self): self.fail("Received incorrect values") cursor.close() + # Delete from table using prepared statements cursor = conn.cursor(cursor_class=MySQLCursorPrepared) cursor.execute('delete from vt_prepare_stmt_test where text_col = %s', (text_value,)) cursor.close() + # Validate Deletion cursor = conn.cursor(cursor_class=MySQLCursorPrepared) cursor.execute('select count(*) from vt_prepare_stmt_test') res = cursor.fetchone() From 0fedf799903264632e4e037c0b1d76df75ebf495 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 7 Aug 2019 09:53:30 -0700 Subject: [PATCH 28/32] Fix Tests Signed-off-by: Saif Alharthi --- docker/test/run.sh | 2 +- go/mysql/conn.go | 28 +++++++--------- go/mysql/fakesqldb/server.go | 6 ++-- go/mysql/query.go | 33 +++++++++---------- go/mysql/query_test.go | 6 ++-- go/mysql/server.go | 3 +- go/mysql/server_test.go | 4 +-- go/vt/vtgate/executor.go | 23 +++++-------- go/vt/vtgate/executor_framework_test.go | 2 +- go/vt/vtgate/plugin_mysql_server.go | 8 ++--- go/vt/vtgate/plugin_mysql_server_test.go | 5 +-- go/vt/vtgate/vtgate.go | 8 ++--- go/vt/vtqueryserver/plugin_mysql_server.go | 4 +-- .../vtqueryserver/plugin_mysql_server_test.go | 5 +-- 14 files changed, 64 insertions(+), 73 deletions(-) diff --git a/docker/test/run.sh b/docker/test/run.sh index 15b305d2206..23082a1c7e0 100755 --- a/docker/test/run.sh +++ b/docker/test/run.sh @@ -129,7 +129,7 @@ chmod -R o=g . # "Failed to move to new namespace: PID namespaces supported, Network namespace supported, but failed: errno = Operation not permitted" args="$args --cap-add=SYS_ADMIN" -args="$args -v /dev/log:/dev/log" +args="$args -v /private/var/run:/dev/log" args="$args -v $PWD:/tmp/src" # Share maven dependency cache so they don't have to be redownloaded every time. diff --git a/go/mysql/conn.go b/go/mysql/conn.go index ab0143f9c4b..e2d7f3ce382 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -841,6 +841,7 @@ func (c *Conn) handleNextCommand(handler Handler) error { return fmt.Errorf("can not prepare multiple statements") } + // Popoulate PrepareData c.StatementID++ prepare := &PrepareData{ StatementID: c.StatementID, @@ -872,24 +873,10 @@ func (c *Conn) handleNextCommand(handler Handler) error { c.PrepareData[c.StatementID] = prepare fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - // TODO(saifalharthi) change the function to return a field. - err = handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - if !fieldSent { - fieldSent = true - if err := c.writePrepare(qr, c.PrepareData[c.StatementID]); err != nil { - return err - } - } + // TODO(saifalharthi) change the function to return a field. + fld, err := handler.ComPrepare(c, queries[0]) - return nil - }) if err != nil { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. @@ -897,6 +884,13 @@ func (c *Conn) handleNextCommand(handler Handler) error { return werr } + if !fieldSent { + fieldSent = true + if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { + return err + } + } + delete(c.PrepareData, c.StatementID) return nil } @@ -1237,7 +1231,7 @@ func ParseErrorPacket(data []byte) error { return NewSQLError(int(code), string(sqlState), "%v", msg) } -// This method gets TLS certificates +// GetTLSClientCerts gets TLS certificates. func (c *Conn) GetTLSClientCerts() []*x509.Certificate { if tlsConn, ok := c.conn.(*tls.Conn); ok { return tlsConn.ConnectionState().PeerCertificates diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 77da9ea8905..14d9869095b 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -31,6 +31,8 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" + + querypb "vitess.io/vitess/go/vt/proto/query" ) const appendEntry = -1 @@ -433,8 +435,8 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) { } // ComPrepare is part of the mysql.Handler interface. -func (db *DB) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - return nil +func (db *DB) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) { + return nil, nil } // ComStmtExecute is part of the mysql.Handler interface. diff --git a/go/mysql/query.go b/go/mysql/query.go index dc328998eca..158ae3ea308 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -1002,11 +1002,11 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn } // writePrepare writes a prepare query response to the wire. -func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error { +func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error { paramsCount := prepare.ParamsCount columnCount := 0 - if result != nil { - columnCount = len(result.Fields) + if len(fld) != 0 { + columnCount = len(fld) } if columnCount > 0 { prepare.ColumnNames = make([]string, columnCount) @@ -1045,23 +1045,20 @@ func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error } } - if result != nil { - // Now send each Field. - for i, field := range result.Fields { - field.Name = strings.Replace(field.Name, "'?'", "?", -1) - prepare.ColumnNames[i] = field.Name - if err := c.writeColumnDefinition(field); err != nil { - return err - } + for i, field := range fld { + field.Name = strings.Replace(field.Name, "'?'", "?", -1) + prepare.ColumnNames[i] = field.Name + if err := c.writeColumnDefinition(field); err != nil { + return err } + } - if columnCount > 0 { - // Now send an EOF packet. - if c.Capabilities&CapabilityClientDeprecateEOF == 0 { - // With CapabilityClientDeprecateEOF, we do not send this EOF. - if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { - return err - } + if columnCount > 0 { + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err } } } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 8a907549c67..b41524fc237 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -148,7 +148,7 @@ func TestComStmtPrepare(t *testing.T) { sConn.PrepareData[prepare.StatementID] = prepare // write the response to the client - if err := sConn.writePrepare(result, prepare); err != nil { + if err := sConn.writePrepare(result.Fields, prepare); err != nil { t.Fatalf("sConn.writePrepare failed: %v", err) } @@ -172,7 +172,7 @@ func TestComStmtSendLongData(t *testing.T) { prepare, result := MockPrepareData(t) cConn.PrepareData = make(map[uint32]*PrepareData) cConn.PrepareData[prepare.StatementID] = prepare - if err := cConn.writePrepare(result, prepare); err != nil { + if err := cConn.writePrepare(result.Fields, prepare); err != nil { t.Fatalf("writePrepare failed: %v", err) } @@ -233,7 +233,7 @@ func TestComStmtClose(t *testing.T) { prepare, result := MockPrepareData(t) cConn.PrepareData = make(map[uint32]*PrepareData) cConn.PrepareData[prepare.StatementID] = prepare - if err := cConn.writePrepare(result, prepare); err != nil { + if err := cConn.writePrepare(result.Fields, prepare); err != nil { t.Fatalf("writePrepare failed: %v", err) } diff --git a/go/mysql/server.go b/go/mysql/server.go index fb5733c6d24..5de219ff476 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -29,6 +29,7 @@ import ( "vitess.io/vitess/go/sync2" "vitess.io/vitess/go/tb" "vitess.io/vitess/go/vt/log" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" ) @@ -96,7 +97,7 @@ type Handler interface { // ComPrepare is called when a connection receives a prepared // statement query. - ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error + ComPrepare(c *Conn, query string) ([]*querypb.Field, error) // ComStmtExecute is called when a connection receives a statement // execute query. diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 9609268ec14..49456128c9f 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -172,8 +172,8 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } -func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error { - return nil +func (th *testHandler) ComPrepare(c *Conn, query string) ([]*querypb.Field, error) { + return nil, nil } func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 8e9f20568ff..5a4df07e611 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1419,9 +1419,9 @@ func buildVarCharRow(values ...string) []sqltypes.Value { } // Prepare executes a prepare statements. -func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { +func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) { logStats := NewLogStats(ctx, method, sql, bindVars) - result, err = e.prepare(ctx, safeSession, sql, bindVars, logStats) + fld, err = e.prepare(ctx, safeSession, sql, bindVars, logStats) logStats.Error = err // The mysql plugin runs an implicit rollback whenever a connection closes. @@ -1430,10 +1430,10 @@ func (e *Executor) Prepare(ctx context.Context, method string, safeSession *Safe if !(logStats.StmtType == "ROLLBACK" && logStats.ShardQueries == 0) { logStats.Send() } - return result, err + return fld, err } -func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (*sqltypes.Result, error) { +func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) ([]*querypb.Field, error) { // Start an implicit transaction if necessary. if !safeSession.Autocommit && !safeSession.InTransaction() { if err := e.txConn.Begin(ctx, safeSession); err != nil { @@ -1472,14 +1472,15 @@ func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql st return e.handlePrepare(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, logStats) case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete, sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtComment: - return &sqltypes.Result{}, nil + return nil, nil case sqlparser.StmtShow: - return e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) + res, err := e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) + return res.Fields, err } return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql) } -func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats) (*sqltypes.Result, error) { +func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats) ([]*querypb.Field, error) { // V3 mode. query, comments := sqlparser.SplitMarginComments(sql) vcursor := newVCursorImpl(ctx, safeSession, destKeyspace, destTabletType, comments, e, logStats) @@ -1509,13 +1510,7 @@ func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, logStats.RowsAffected = qr.RowsAffected } - // Check if there was partial DML execution. If so, rollback the transaction. - if err != nil { - _ = e.txConn.Rollback(ctx, safeSession) - err = vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction rolled back due to partial DML execution: %v", err) - } - plan.AddStats(1, time.Since(logStats.StartTime), uint64(logStats.ShardQueries), logStats.RowsAffected, errCount) - return qr, err + return qr.Fields, err } diff --git a/go/vt/vtgate/executor_framework_test.go b/go/vt/vtgate/executor_framework_test.go index 85ac2c5c3e6..88073f32f69 100644 --- a/go/vt/vtgate/executor_framework_test.go +++ b/go/vt/vtgate/executor_framework_test.go @@ -392,7 +392,7 @@ func executorExec(executor *Executor, sql string, bv map[string]*querypb.BindVar bv) } -func executorPrepare(executor *Executor, sql string, bv map[string]*querypb.BindVariable) (*sqltypes.Result, error) { +func executorPrepare(executor *Executor, sql string, bv map[string]*querypb.BindVariable) ([]*querypb.Field, error) { return executor.Prepare( context.Background(), "TestExecute", diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 5408756f3c1..85240db4804 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -163,7 +163,7 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq } // ComPrepare is the handler for command prepare. -func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { +func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) { var ctx context.Context var cancel context.CancelFunc if *mysqlQueryTimeout != 0 { @@ -213,13 +213,13 @@ func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(* session.TargetString = c.SchemaName } - session, result, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) + session, fld, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) c.ClientData = session err = mysql.NewSQLErrorFromError(err) if err != nil { - return err + return nil, err } - return callback(result) + return fld, nil } func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 07d03fa97b2..c190f053e83 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -26,6 +26,7 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) type testHandler struct { @@ -43,8 +44,8 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } -func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { - return nil +func (th *testHandler) ComPrepare(c *mysql.Conn, q string) ([]*querypb.Field, error) { + return nil, nil } func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index cc2a68d3222..5b5094d9626 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -834,7 +834,7 @@ func (vtg *VTGate) ResolveTransaction(ctx context.Context, dtid string) error { } // Prepare supports non-streaming prepare statement query with multi shards -func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, qr *sqltypes.Result, err error) { +func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, fld []*querypb.Field, err error) { // In this context, we don't care if we can't fully parse destination destKeyspace, destTabletType, _, _ := vtg.executor.ParseDestinationTarget(session.TargetString) statsKey := []string{"Execute", destKeyspace, topoproto.TabletTypeLString(destTabletType)} @@ -845,10 +845,10 @@ func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql s goto handleError } - qr, err = vtg.executor.Prepare(ctx, "Prepare", NewSafeSession(session), sql, bindVariables) + fld, err = vtg.executor.Prepare(ctx, "Prepare", NewSafeSession(session), sql, bindVariables) if err == nil { - vtg.rowsReturned.Add(statsKey, int64(len(qr.Rows))) - return session, qr, nil + vtg.rowsReturned.Add(statsKey, int64(len(fld))) + return session, fld, nil } handleError: diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index 6e86a21a172..a900472d6bf 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -136,8 +136,8 @@ func (mh *proxyHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } -func (mh *proxyHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { - return nil +func (mh *proxyHandler) ComPrepare(c *mysql.Conn, query string) ([]*querypb.Field, error) { + return nil, nil } func (mh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { diff --git a/go/vt/vtqueryserver/plugin_mysql_server_test.go b/go/vt/vtqueryserver/plugin_mysql_server_test.go index 936962075e7..6c538b80a37 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server_test.go +++ b/go/vt/vtqueryserver/plugin_mysql_server_test.go @@ -26,6 +26,7 @@ import ( "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" ) type testHandler struct { @@ -43,8 +44,8 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } -func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { - return nil +func (th *testHandler) ComPrepare(c *mysql.Conn, q string) ([]*querypb.Field, error) { + return nil, nil } func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { From 3f67a4a4bae52d79ad31baf129904a5bf09a8b2a Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Wed, 7 Aug 2019 10:14:50 -0700 Subject: [PATCH 29/32] Fix mount path Signed-off-by: Saif Alharthi --- docker/test/run.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/test/run.sh b/docker/test/run.sh index a4b2c342965..dc59fad26f4 100755 --- a/docker/test/run.sh +++ b/docker/test/run.sh @@ -129,7 +129,7 @@ chmod -R o=g . # "Failed to move to new namespace: PID namespaces supported, Network namespace supported, but failed: errno = Operation not permitted" args="$args --cap-add=SYS_ADMIN" -args="$args -v /private/var/run:/dev/log" +args="$args -v /dev/log:/dev/log" args="$args -v $PWD:/tmp/src" # Share maven dependency cache so they don't have to be redownloaded every time. From 58a810f82363d83289e97462f206b248ec116371 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 8 Aug 2019 22:05:45 -0700 Subject: [PATCH 30/32] Fixed tests and bug in ComPrepare Signed-off-by: Saif Alharthi --- bootstrap.sh | 2 ++ go/mysql/conn.go | 14 ++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/bootstrap.sh b/bootstrap.sh index 5a2add5b084..51dcab0fd23 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -361,6 +361,8 @@ if [ "$BUILD_TESTS" == 1 ] ; then echo "$MYSQL_FLAVOR" > "$VTROOT/dist/MYSQL_FLAVOR" fi +PYTHONPATH='' $PIP install mysql-connector-python + # # 4. Installation of development related steps e.g. creating Git hooks. # diff --git a/go/mysql/conn.go b/go/mysql/conn.go index e2d7f3ce382..2e4450c0ca7 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -883,17 +883,15 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } + } - if !fieldSent { - fieldSent = true - if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { - return err - } + if !fieldSent { + fieldSent = true + if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { + return err } - - delete(c.PrepareData, c.StatementID) - return nil } + case ComStmtExecute: queryStart := time.Now() stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) From d90e7233732aa9da1af039173ec6d2abf5a9dc96 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 8 Aug 2019 23:02:53 -0700 Subject: [PATCH 31/32] Address comments Signed-off-by: Saif Alharthi --- go/mysql/conn.go | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 2e4450c0ca7..42b8215c86c 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -872,9 +872,6 @@ func (c *Conn) handleNextCommand(handler Handler) error { c.PrepareData[c.StatementID] = prepare - fieldSent := false - - // TODO(saifalharthi) change the function to return a field. fld, err := handler.ComPrepare(c, queries[0]) if err != nil { @@ -883,13 +880,11 @@ func (c *Conn) handleNextCommand(handler Handler) error { log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } + return nil } - if !fieldSent { - fieldSent = true - if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { - return err - } + if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { + return err } case ComStmtExecute: From 112187e0a5ac6afd0d257c225fd3fceb2b9237b6 Mon Sep 17 00:00:00 2001 From: Saif Alharthi Date: Thu, 8 Aug 2019 23:22:22 -0700 Subject: [PATCH 32/32] Added end to end test fail ComPrepare and make sure other queries do not get affacted Signed-off-by: Saif Alharthi --- test/prepared_statement_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/prepared_statement_test.py b/test/prepared_statement_test.py index 653ba34099a..963f2550c6f 100755 --- a/test/prepared_statement_test.py +++ b/test/prepared_statement_test.py @@ -27,6 +27,7 @@ import mysql.connector from mysql.connector import FieldType from mysql.connector.cursor import MySQLCursorPrepared +from mysql.connector.errors import Error import environment import utils @@ -229,6 +230,16 @@ def test_prepared_statements(self): cursor.fetchone() cursor.close() + cursor = conn.cursor() + try: + cursor.execute('selet * from vt_prepare_stmt_test', {}) + cursor.close() + except mysql.connector.Error as err: + if err.errno == 1105: + print "Captured the error" + else: + raise + # Insert several rows using prepared statements text_value = "text" * 100 # Large text value largeComment = 'L' * ((4 * 1024) + 1) # Large blob