diff --git a/changelog/22.0/22.0.0/summary.md b/changelog/22.0/22.0.0/summary.md index 51a9fc9dd41..401b54eb4fd 100644 --- a/changelog/22.0/22.0.0/summary.md +++ b/changelog/22.0/22.0.0/summary.md @@ -27,6 +27,7 @@ - [LAST_INSERT_ID(x)](#last-insert-id) - [Maximum Idle Connections in the Pool](#max-idle-connections) - [Filtering Query logs on Error](#query-logs) + - [MultiQuery RPC in vtgate](#multiquery) - **[Optimization](#optimization)** - [Prepared Statement](#prepared-statement) - **[RPC Changes](#rpc-changes)** @@ -280,6 +281,14 @@ The `querylog-mode` setting can be configured to `error` to log only queries tha --- +#### MultiQuery RPC in vtgate + +New RPCs in vtgate have been added that allow users to pass multiple queries in a single sql string. It behaves the same way MySQL does where-in multiple result sets for the queries are returned in the same order as the queries were passed until an error is encountered. The new RPCs are `ExecuteMulti` and `StreamExecuteMulti`. + +A new flag `--mysql-server-multi-query-protocol` has also been added that makes the server use this new implementation. This flag is set to `false` by default, so the old implementation is used by default. The new implementation is more efficient and allows for better performance when executing multiple queries in a single RPC call. + +--- + ### Optimization #### Prepared Statement diff --git a/go/cmd/vtgateclienttest/services/callerid.go b/go/cmd/vtgateclienttest/services/callerid.go index f66644d3996..6c708199056 100644 --- a/go/cmd/vtgateclienttest/services/callerid.go +++ b/go/cmd/vtgateclienttest/services/callerid.go @@ -29,6 +29,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vtgateservice" ) @@ -104,3 +105,40 @@ func (c *callerIDClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservi } return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback) } + +// ExecuteMulti is part of the VTGateService interface +func (c *callerIDClient) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, nil, err + } + var result *sqltypes.Result + for _, query := range queries { + session, result, err = c.Execute(ctx, mysqlCtx, session, query, nil, false) + if err != nil { + return session, qrs, err + } + qrs = append(qrs, result) + } + return session, qrs, nil +} + +// StreamExecuteMulti is part of the VTGateService interface +func (c *callerIDClient) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, err + } + for idx, query := range queries { + firstPacket := true + session, err = c.StreamExecute(ctx, mysqlCtx, session, query, nil, func(result *sqltypes.Result) error { + err = callback(sqltypes.QueryResponse{QueryResult: result}, idx < len(queries)-1, firstPacket) + firstPacket = false + return err + }) + if err != nil { + return session, err + } + } + return session, nil +} diff --git a/go/cmd/vtgateclienttest/services/echo.go b/go/cmd/vtgateclienttest/services/echo.go index 7ffe5b1cfd5..1f8b72c408b 100644 --- a/go/cmd/vtgateclienttest/services/echo.go +++ b/go/cmd/vtgateclienttest/services/echo.go @@ -130,6 +130,18 @@ func (c *echoClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.M return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback) } +// ExecuteMulti is part of the VTGateService interface +func (c *echoClient) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + // Look at https://github.com/vitessio/vitess/pull/18059 for details on how to implement this. + panic("unimplemented") +} + +// StreamExecuteMulti is part of the VTGateService interface +func (c *echoClient) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + // Look at https://github.com/vitessio/vitess/pull/18059 for details on how to implement this. + panic("unimplemented") +} + func (c *echoClient) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) { if len(sqlList) > 0 && strings.HasPrefix(sqlList[0], EchoPrefix) { var queryResponse []sqltypes.QueryResponse diff --git a/go/cmd/vtgateclienttest/services/errors.go b/go/cmd/vtgateclienttest/services/errors.go index f1438a105e4..39d4d0b1e2b 100644 --- a/go/cmd/vtgateclienttest/services/errors.go +++ b/go/cmd/vtgateclienttest/services/errors.go @@ -146,6 +146,18 @@ func (c *errorClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservice. return c.fallbackClient.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback) } +// ExecuteMulti is part of the VTGateService interface +func (c *errorClient) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + // Look at https://github.com/vitessio/vitess/pull/18059 for details on how to implement this. + panic("unimplemented") +} + +// StreamExecuteMulti is part of the VTGateService interface +func (c *errorClient) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + // Look at https://github.com/vitessio/vitess/pull/18059 for details on how to implement this. + panic("unimplemented") +} + func (c *errorClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) { if err := requestToPartialError(sql, session); err != nil { return session, nil, 0, err diff --git a/go/cmd/vtgateclienttest/services/fallback.go b/go/cmd/vtgateclienttest/services/fallback.go index 00ad33776df..085652bdeb0 100644 --- a/go/cmd/vtgateclienttest/services/fallback.go +++ b/go/cmd/vtgateclienttest/services/fallback.go @@ -59,6 +59,14 @@ func (c fallbackClient) StreamExecute(ctx context.Context, mysqlCtx vtgateservic return c.fallback.StreamExecute(ctx, mysqlCtx, session, sql, bindVariables, callback) } +func (c fallbackClient) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + return c.fallback.ExecuteMulti(ctx, mysqlCtx, session, sqlString) +} + +func (c fallbackClient) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + return c.fallback.StreamExecuteMulti(ctx, mysqlCtx, session, sqlString, callback) +} + func (c fallbackClient) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) { return c.fallback.Prepare(ctx, session, sql) } diff --git a/go/cmd/vtgateclienttest/services/terminal.go b/go/cmd/vtgateclienttest/services/terminal.go index ad1937566f1..e6a853ae306 100644 --- a/go/cmd/vtgateclienttest/services/terminal.go +++ b/go/cmd/vtgateclienttest/services/terminal.go @@ -74,6 +74,14 @@ func (c *terminalClient) Prepare(ctx context.Context, session *vtgatepb.Session, return session, nil, 0, errTerminal } +func (c *terminalClient) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + return session, nil, errTerminal +} + +func (c *terminalClient) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + return session, errTerminal +} + func (c *terminalClient) CloseSession(ctx context.Context, session *vtgatepb.Session) error { return errTerminal } diff --git a/go/flags/endtoend/vtcombo.txt b/go/flags/endtoend/vtcombo.txt index cf382598cae..eac59edb198 100644 --- a/go/flags/endtoend/vtcombo.txt +++ b/go/flags/endtoend/vtcombo.txt @@ -219,6 +219,7 @@ Flags: --mycnf_tmp_dir string mysql tmp directory --mysql-server-drain-onterm If set, the server waits for --onterm_timeout for already connected clients to complete their in flight work --mysql-server-keepalive-period duration TCP period between keep-alives + --mysql-server-multi-query-protocol If set, the server will use the new implementation of handling queries where-in multiple queries are sent together. --mysql-server-pool-conn-read-buffers If set, the server will pool incoming connection read buffers --mysql-shell-backup-location string location where the backup will be stored --mysql-shell-dump-flags string flags to pass to mysql shell dump utility. This should be a JSON string and will be saved in the MANIFEST (default "{\"threads\": 4}") diff --git a/go/flags/endtoend/vtgate.txt b/go/flags/endtoend/vtgate.txt index 1a4e6ace312..63b4d8e1d0d 100644 --- a/go/flags/endtoend/vtgate.txt +++ b/go/flags/endtoend/vtgate.txt @@ -123,6 +123,7 @@ Flags: --min_number_serving_vttablets int The minimum number of vttablets for each replicating tablet_type (e.g. replica, rdonly) that will be continue to be used even with replication lag above discovery_low_replication_lag, but still below discovery_high_replication_lag_minimum_serving. (default 2) --mysql-server-drain-onterm If set, the server waits for --onterm_timeout for already connected clients to complete their in flight work --mysql-server-keepalive-period duration TCP period between keep-alives + --mysql-server-multi-query-protocol If set, the server will use the new implementation of handling queries where-in multiple queries are sent together. --mysql-server-pool-conn-read-buffers If set, the server will pool incoming connection read buffers --mysql_allow_clear_text_without_tls If set, the server will allow the use of a clear text password over non-SSL connections. --mysql_auth_server_impl string Which auth server implementation to use. Options: none, ldap, clientcert, static, vault. (default "static") diff --git a/go/mysql/conn.go b/go/mysql/conn.go index cfe65e07166..07960ec4146 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -30,6 +30,8 @@ import ( "sync/atomic" "time" + "github.com/spf13/pflag" + "vitess.io/vitess/go/bucketpool" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/sqlerror" @@ -38,6 +40,7 @@ import ( "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vterrors" ) @@ -67,6 +70,19 @@ const ( ephemeralRead ) +var ( + mysqlMultiQuery = false +) + +func registerConnFlags(fs *pflag.FlagSet) { + fs.BoolVar(&mysqlMultiQuery, "mysql-server-multi-query-protocol", mysqlMultiQuery, "If set, the server will use the new implementation of handling queries where-in multiple queries are sent together.") +} + +func init() { + servenv.OnParseFor("vtgate", registerConnFlags) + servenv.OnParseFor("vtcombo", registerConnFlags) +} + // A Getter has a Get() type Getter interface { Get() *querypb.VTGateCallerID @@ -914,6 +930,9 @@ func (c *Conn) handleNextCommand(handler Handler) bool { res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) return res != connErr case ComQuery: + if mysqlMultiQuery { + return c.handleComQueryMulti(handler, data) + } return c.handleComQuery(handler, data) case ComPing: return c.handleComPing() @@ -1279,6 +1298,141 @@ func (c *Conn) handleComPing() bool { return true } +// handleComQueryMulti is a newer version of handleComQuery that uses +// the StreamExecuteMulti and ExecuteMulti RPC calls to push the splitting of statements +// down to Vtgate. +func (c *Conn) handleComQueryMulti(handler Handler, data []byte) (kontinue bool) { + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + kontinue = false + } + }() + + queryStart := time.Now() + query := c.parseComQuery(data) + c.recycleReadPacket() + + res := c.execQueryMulti(query, handler) + if res != execSuccess { + return res != connErr + } + + timings.Record(queryTimingKey, queryStart) + return true +} + +// execQueryMulti is a newer version of execQuery that uses +// the StreamExecuteMulti and ExecuteMulti RPC calls to push the splitting of statements +// down to Vtgate. +func (c *Conn) execQueryMulti(query string, handler Handler) execResult { + // needsEndPacket signifies whether we have need to send the last packet to the client + // for a given query. This is used to determine whether we should send an + // end packet after the query is done or not. Initially we don't need to send an end packet + // so we initialize this value to false. + needsEndPacket := false + callbackCalled := false + var res = execSuccess + + err := handler.ComQueryMulti(c, query, func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error { + callbackCalled = true + flag := c.StatusFlags + if more { + flag |= ServerMoreResultsExists + } + + // firstPacket tells us that this is the start of a new query result. + // If we haven't sent a last packet yet, we should send the end result packet. + if firstPacket && needsEndPacket { + if err := c.writeEndResult(true, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return err + } + } + + // We receive execution errors in a query as part of the QueryResponse. + // We check for those errors and send a error packet. If we are unable + // to send the error packet, then there is a connection error too. + if qr.QueryError != nil { + res = execErr + if !c.writeErrorPacketFromErrorAndLog(qr.QueryError) { + res = connErr + } + return nil + } + + if firstPacket { + // The first packet signifies the start of a new query result. + // So we reset the needsEndPacket variable to signify we haven't sent the last + // packet for this query. + needsEndPacket = true + if len(qr.QueryResult.Fields) == 0 { + + // A successful callback with no fields means that this was a + // DML or other write-only operation. + // + // We should not send any more packets after this, but make sure + // to extract the affected rows and last insert id from the result + // struct here since clients expect it. + ok := PacketOK{ + affectedRows: qr.QueryResult.RowsAffected, + lastInsertID: qr.QueryResult.InsertID, + statusFlags: flag, + warnings: handler.WarningCount(c), + info: "", + sessionStateData: qr.QueryResult.SessionStateChanges, + } + needsEndPacket = false + return c.writeOKPacket(&ok) + } + + if err := c.writeFields(qr.QueryResult); err != nil { + return err + } + } + + return c.writeRows(qr.QueryResult) + }) + + // If callback was not called, we expect an error. + // It is possible that we don't get a callback if some condition checks + // fail before the query starts execution. In this case, we need to write some + // error back. + if !callbackCalled { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = sqlerror.NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if !c.writeErrorPacketFromErrorAndLog(err) { + return connErr + } + return execErr + } + + if res != execSuccess { + // We failed during the stream itself. + return res + } + + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return connErr + } + + // If we haven't sent the final packet for the last query, we should send that too. + if needsEndPacket { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return connErr + } + } + + return execSuccess +} + var errEmptyStatement = sqlerror.NewSQLError(sqlerror.EREmptyQuery, sqlerror.SSClientError, "Query was empty") func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) { diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 7520493dbfc..96f707eec5e 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -40,6 +40,7 @@ import ( "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" ) @@ -803,104 +804,165 @@ func TestIsEOFPacket(t *testing.T) { } func TestMultiStatementStopsOnError(t *testing.T) { - listener, sConn, cConn := createSocketPair(t) - sConn.Capabilities |= CapabilityClientMultiStatements + origMysqlMultiQuery := mysqlMultiQuery defer func() { - listener.Close() - sConn.Close() - cConn.Close() + mysqlMultiQuery = origMysqlMultiQuery }() + for _, b := range []bool{true, false} { + t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { + mysqlMultiQuery = b + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("error;select 2") + require.NoError(t, err) - err := cConn.WriteComQuery("error;select 2") - require.NoError(t, err) + // this handler will return results according to the query. In case the query contains "error" it will return an error + // panic if the query contains "panic" and it will return selectRowsResult in case of any other query + handler := &testRun{err: fmt.Errorf("execution failed")} + res := sConn.handleNextCommand(handler) + // Execution error will occur in this case because the query sent is error and testRun will throw an error. + // We should send an error packet but not close the connection. + require.True(t, res, "we should not break the connection because of execution errors") - // this handler will return results according to the query. In case the query contains "error" it will return an error - // panic if the query contains "panic" and it will return selectRowsResult in case of any other query - handler := &testRun{err: fmt.Errorf("execution failed")} - res := sConn.handleNextCommand(handler) - // Execution error will occur in this case because the query sent is error and testRun will throw an error. - // We should send an error packet but not close the connection. - require.True(t, res, "we should not break the connection because of execution errors") + data, err := cConn.ReadPacket() + require.NoError(t, err) + require.NotEmpty(t, data) + require.EqualValues(t, data[0], ErrPacket) // we should see the error here + }) + } +} - data, err := cConn.ReadPacket() - require.NoError(t, err) - require.NotEmpty(t, data) - require.EqualValues(t, data[0], ErrPacket) // we should see the error here +func TestEmptyQuery(t *testing.T) { + origMysqlMultiQuery := mysqlMultiQuery + defer func() { + mysqlMultiQuery = origMysqlMultiQuery + }() + for _, b := range []bool{true, false} { + t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { + mysqlMultiQuery = b + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("") + require.NoError(t, err) + + // this handler will return results according to the query. In case the query contains "error" it will return an error + // panic if the query contains "panic" and it will return selectRowsResult in case of any other query + handler := &testRun{err: sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")} + res := sConn.handleNextCommand(handler) + // The queries run will be an empty query; Even with the empty error, the connection should be fine + require.True(t, res, "we should not break the connection in case of no errors") + // Read the result and assert that we indeed see the error for empty query. + data, more, _, err := cConn.ReadQueryResult(100, true) + require.EqualError(t, err, "Query was empty (errno 1065) (sqlstate 42000)") + require.False(t, more) + require.Nil(t, data) + }) + } } func TestMultiStatement(t *testing.T) { - listener, sConn, cConn := createSocketPair(t) - sConn.Capabilities |= CapabilityClientMultiStatements + origMysqlMultiQuery := mysqlMultiQuery defer func() { - listener.Close() - sConn.Close() - cConn.Close() + mysqlMultiQuery = origMysqlMultiQuery }() + for _, b := range []bool{true, false} { + t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { + mysqlMultiQuery = b + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("select 1;select 2") + require.NoError(t, err) - err := cConn.WriteComQuery("select 1;select 2") - require.NoError(t, err) - - // this handler will return results according to the query. In case the query contains "error" it will return an error - // panic if the query contains "panic" and it will return selectRowsResult in case of any other query - handler := &testRun{err: sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")} - res := sConn.handleNextCommand(handler) - // The queries run will be select 1; and select 2; These queries do not return any errors, so the connection should still be open - require.True(t, res, "we should not break the connection in case of no errors") - // Read the result of the query and assert that it is indeed what we want. This will contain the result of the first query. - data, more, _, err := cConn.ReadQueryResult(100, true) - require.NoError(t, err) - // Since we executed 2 queries, there should be more results to be read - require.True(t, more) - require.True(t, data.Equal(selectRowsResult)) + // this handler will return results according to the query. In case the query contains "error" it will return an error + // panic if the query contains "panic" and it will return selectRowsResult in case of any other query + handler := &testRun{err: sqlerror.NewSQLError(sqlerror.CRMalformedPacket, sqlerror.SSUnknownSQLState, "cannot get column number")} + res := sConn.handleNextCommand(handler) + // The queries run will be select 1; and select 2; These queries do not return any errors, so the connection should still be open + require.True(t, res, "we should not break the connection in case of no errors") + // Read the result of the query and assert that it is indeed what we want. This will contain the result of the first query. + data, more, _, err := cConn.ReadQueryResult(100, true) + require.NoError(t, err) + // Since we executed 2 queries, there should be more results to be read + require.True(t, more) + require.True(t, data.Equal(selectRowsResult)) - // Read the results for the second query and verify the correctness - data, more, _, err = cConn.ReadQueryResult(100, true) - require.NoError(t, err) - // This was the final query run, so we expect that more should be false as there are no more queries. - require.False(t, more) - require.True(t, data.Equal(selectRowsResult)) + // Read the results for the second query and verify the correctness + data, more, _, err = cConn.ReadQueryResult(100, true) + require.NoError(t, err) + // This was the final query run, so we expect that more should be false as there are no more queries. + require.False(t, more) + require.True(t, data.Equal(selectRowsResult)) - // This time we run two queries fist of which will return an error - err = cConn.WriteComQuery("error;select 2") - require.NoError(t, err) + // This time we run two queries fist of which will return an error + err = cConn.WriteComQuery("error;select 2") + require.NoError(t, err) - res = sConn.handleNextCommand(handler) - // Even if the query returns an error we should not close the connection as it is an execution error - require.True(t, res, "we should not break the connection because of execution errors") + res = sConn.handleNextCommand(handler) + // Even if the query returns an error we should not close the connection as it is an execution error + require.True(t, res, "we should not break the connection because of execution errors") - // Read the result and assert that we indeed see the error that testRun throws. - data, more, _, err = cConn.ReadQueryResult(100, true) - require.EqualError(t, err, "cannot get column number (errno 2027) (sqlstate HY000)") - // In case of errors in a multi-statement, the following statements are not executed, therefore we want that more should be false - require.False(t, more) - require.Nil(t, data) + // Read the result and assert that we indeed see the error that testRun throws. + data, more, _, err = cConn.ReadQueryResult(100, true) + require.EqualError(t, err, "cannot get column number (errno 2027) (sqlstate HY000)") + // In case of errors in a multi-statement, the following statements are not executed, therefore we want that more should be false + require.False(t, more) + require.Nil(t, data) + }) + } } func TestMultiStatementOnSplitError(t *testing.T) { - listener, sConn, cConn := createSocketPair(t) - sConn.Capabilities |= CapabilityClientMultiStatements + origMysqlMultiQuery := mysqlMultiQuery defer func() { - listener.Close() - sConn.Close() - cConn.Close() + mysqlMultiQuery = origMysqlMultiQuery }() + for _, b := range []bool{true, false} { + t.Run(fmt.Sprintf("MultiQueryProtocol: %v", b), func(t *testing.T) { + mysqlMultiQuery = b + listener, sConn, cConn := createSocketPair(t) + sConn.Capabilities |= CapabilityClientMultiStatements + defer func() { + listener.Close() + sConn.Close() + cConn.Close() + }() + + err := cConn.WriteComQuery("broken>'query 1;parse'query 1;parse vtgate.Session.ShardSession - 25, // 1: vtgate.Session.options:type_name -> query.ExecuteOptions + 24, // 0: vtgate.Session.shard_sessions:type_name -> vtgate.Session.ShardSession + 29, // 1: vtgate.Session.options:type_name -> query.ExecuteOptions 0, // 2: vtgate.Session.transaction_mode:type_name -> vtgate.TransactionMode - 26, // 3: vtgate.Session.warnings:type_name -> query.QueryWarning - 20, // 4: vtgate.Session.pre_sessions:type_name -> vtgate.Session.ShardSession - 20, // 5: vtgate.Session.post_sessions:type_name -> vtgate.Session.ShardSession - 21, // 6: vtgate.Session.user_defined_variables:type_name -> vtgate.Session.UserDefinedVariablesEntry - 22, // 7: vtgate.Session.system_variables:type_name -> vtgate.Session.SystemVariablesEntry - 20, // 8: vtgate.Session.lock_session:type_name -> vtgate.Session.ShardSession + 30, // 3: vtgate.Session.warnings:type_name -> query.QueryWarning + 24, // 4: vtgate.Session.pre_sessions:type_name -> vtgate.Session.ShardSession + 24, // 5: vtgate.Session.post_sessions:type_name -> vtgate.Session.ShardSession + 25, // 6: vtgate.Session.user_defined_variables:type_name -> vtgate.Session.UserDefinedVariablesEntry + 26, // 7: vtgate.Session.system_variables:type_name -> vtgate.Session.SystemVariablesEntry + 24, // 8: vtgate.Session.lock_session:type_name -> vtgate.Session.ShardSession 4, // 9: vtgate.Session.read_after_write:type_name -> vtgate.ReadAfterWrite - 23, // 10: vtgate.Session.advisory_lock:type_name -> vtgate.Session.AdvisoryLockEntry - 24, // 11: vtgate.Session.prepare_statement:type_name -> vtgate.Session.PrepareStatementEntry - 27, // 12: vtgate.ExecuteRequest.caller_id:type_name -> vtrpc.CallerID - 2, // 13: vtgate.ExecuteRequest.session:type_name -> vtgate.Session - 28, // 14: vtgate.ExecuteRequest.query:type_name -> query.BoundQuery - 29, // 15: vtgate.ExecuteResponse.error:type_name -> vtrpc.RPCError - 2, // 16: vtgate.ExecuteResponse.session:type_name -> vtgate.Session - 30, // 17: vtgate.ExecuteResponse.result:type_name -> query.QueryResult - 27, // 18: vtgate.ExecuteBatchRequest.caller_id:type_name -> vtrpc.CallerID - 2, // 19: vtgate.ExecuteBatchRequest.session:type_name -> vtgate.Session - 28, // 20: vtgate.ExecuteBatchRequest.queries:type_name -> query.BoundQuery - 29, // 21: vtgate.ExecuteBatchResponse.error:type_name -> vtrpc.RPCError - 2, // 22: vtgate.ExecuteBatchResponse.session:type_name -> vtgate.Session - 31, // 23: vtgate.ExecuteBatchResponse.results:type_name -> query.ResultWithError - 27, // 24: vtgate.StreamExecuteRequest.caller_id:type_name -> vtrpc.CallerID - 28, // 25: vtgate.StreamExecuteRequest.query:type_name -> query.BoundQuery - 2, // 26: vtgate.StreamExecuteRequest.session:type_name -> vtgate.Session - 30, // 27: vtgate.StreamExecuteResponse.result:type_name -> query.QueryResult - 2, // 28: vtgate.StreamExecuteResponse.session:type_name -> vtgate.Session - 27, // 29: vtgate.ResolveTransactionRequest.caller_id:type_name -> vtrpc.CallerID - 27, // 30: vtgate.VStreamRequest.caller_id:type_name -> vtrpc.CallerID - 32, // 31: vtgate.VStreamRequest.tablet_type:type_name -> topodata.TabletType - 33, // 32: vtgate.VStreamRequest.vgtid:type_name -> binlogdata.VGtid - 34, // 33: vtgate.VStreamRequest.filter:type_name -> binlogdata.Filter - 13, // 34: vtgate.VStreamRequest.flags:type_name -> vtgate.VStreamFlags - 35, // 35: vtgate.VStreamResponse.events:type_name -> binlogdata.VEvent - 27, // 36: vtgate.PrepareRequest.caller_id:type_name -> vtrpc.CallerID - 2, // 37: vtgate.PrepareRequest.session:type_name -> vtgate.Session - 28, // 38: vtgate.PrepareRequest.query:type_name -> query.BoundQuery - 29, // 39: vtgate.PrepareResponse.error:type_name -> vtrpc.RPCError - 2, // 40: vtgate.PrepareResponse.session:type_name -> vtgate.Session - 36, // 41: vtgate.PrepareResponse.fields:type_name -> query.Field - 27, // 42: vtgate.CloseSessionRequest.caller_id:type_name -> vtrpc.CallerID - 2, // 43: vtgate.CloseSessionRequest.session:type_name -> vtgate.Session - 29, // 44: vtgate.CloseSessionResponse.error:type_name -> vtrpc.RPCError - 37, // 45: vtgate.Session.ShardSession.target:type_name -> query.Target - 38, // 46: vtgate.Session.ShardSession.tablet_alias:type_name -> topodata.TabletAlias - 39, // 47: vtgate.Session.UserDefinedVariablesEntry.value:type_name -> query.BindVariable - 3, // 48: vtgate.Session.PrepareStatementEntry.value:type_name -> vtgate.PrepareData - 49, // [49:49] is the sub-list for method output_type - 49, // [49:49] is the sub-list for method input_type - 49, // [49:49] is the sub-list for extension type_name - 49, // [49:49] is the sub-list for extension extendee - 0, // [0:49] is the sub-list for field type_name + 27, // 10: vtgate.Session.advisory_lock:type_name -> vtgate.Session.AdvisoryLockEntry + 28, // 11: vtgate.Session.prepare_statement:type_name -> vtgate.Session.PrepareStatementEntry + 31, // 12: vtgate.ExecuteMultiRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 13: vtgate.ExecuteMultiRequest.session:type_name -> vtgate.Session + 32, // 14: vtgate.ExecuteMultiResponse.error:type_name -> vtrpc.RPCError + 2, // 15: vtgate.ExecuteMultiResponse.session:type_name -> vtgate.Session + 33, // 16: vtgate.ExecuteMultiResponse.results:type_name -> query.QueryResult + 31, // 17: vtgate.ExecuteRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 18: vtgate.ExecuteRequest.session:type_name -> vtgate.Session + 34, // 19: vtgate.ExecuteRequest.query:type_name -> query.BoundQuery + 32, // 20: vtgate.ExecuteResponse.error:type_name -> vtrpc.RPCError + 2, // 21: vtgate.ExecuteResponse.session:type_name -> vtgate.Session + 33, // 22: vtgate.ExecuteResponse.result:type_name -> query.QueryResult + 31, // 23: vtgate.ExecuteBatchRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 24: vtgate.ExecuteBatchRequest.session:type_name -> vtgate.Session + 34, // 25: vtgate.ExecuteBatchRequest.queries:type_name -> query.BoundQuery + 32, // 26: vtgate.ExecuteBatchResponse.error:type_name -> vtrpc.RPCError + 2, // 27: vtgate.ExecuteBatchResponse.session:type_name -> vtgate.Session + 35, // 28: vtgate.ExecuteBatchResponse.results:type_name -> query.ResultWithError + 31, // 29: vtgate.StreamExecuteRequest.caller_id:type_name -> vtrpc.CallerID + 34, // 30: vtgate.StreamExecuteRequest.query:type_name -> query.BoundQuery + 2, // 31: vtgate.StreamExecuteRequest.session:type_name -> vtgate.Session + 33, // 32: vtgate.StreamExecuteResponse.result:type_name -> query.QueryResult + 2, // 33: vtgate.StreamExecuteResponse.session:type_name -> vtgate.Session + 31, // 34: vtgate.StreamExecuteMultiRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 35: vtgate.StreamExecuteMultiRequest.session:type_name -> vtgate.Session + 35, // 36: vtgate.StreamExecuteMultiResponse.result:type_name -> query.ResultWithError + 2, // 37: vtgate.StreamExecuteMultiResponse.session:type_name -> vtgate.Session + 31, // 38: vtgate.ResolveTransactionRequest.caller_id:type_name -> vtrpc.CallerID + 31, // 39: vtgate.VStreamRequest.caller_id:type_name -> vtrpc.CallerID + 36, // 40: vtgate.VStreamRequest.tablet_type:type_name -> topodata.TabletType + 37, // 41: vtgate.VStreamRequest.vgtid:type_name -> binlogdata.VGtid + 38, // 42: vtgate.VStreamRequest.filter:type_name -> binlogdata.Filter + 17, // 43: vtgate.VStreamRequest.flags:type_name -> vtgate.VStreamFlags + 39, // 44: vtgate.VStreamResponse.events:type_name -> binlogdata.VEvent + 31, // 45: vtgate.PrepareRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 46: vtgate.PrepareRequest.session:type_name -> vtgate.Session + 34, // 47: vtgate.PrepareRequest.query:type_name -> query.BoundQuery + 32, // 48: vtgate.PrepareResponse.error:type_name -> vtrpc.RPCError + 2, // 49: vtgate.PrepareResponse.session:type_name -> vtgate.Session + 40, // 50: vtgate.PrepareResponse.fields:type_name -> query.Field + 31, // 51: vtgate.CloseSessionRequest.caller_id:type_name -> vtrpc.CallerID + 2, // 52: vtgate.CloseSessionRequest.session:type_name -> vtgate.Session + 32, // 53: vtgate.CloseSessionResponse.error:type_name -> vtrpc.RPCError + 41, // 54: vtgate.Session.ShardSession.target:type_name -> query.Target + 42, // 55: vtgate.Session.ShardSession.tablet_alias:type_name -> topodata.TabletAlias + 43, // 56: vtgate.Session.UserDefinedVariablesEntry.value:type_name -> query.BindVariable + 3, // 57: vtgate.Session.PrepareStatementEntry.value:type_name -> vtgate.PrepareData + 58, // [58:58] is the sub-list for method output_type + 58, // [58:58] is the sub-list for method input_type + 58, // [58:58] is the sub-list for extension type_name + 58, // [58:58] is the sub-list for extension extendee + 0, // [0:58] is the sub-list for field type_name } func init() { file_vtgate_proto_init() } @@ -2043,7 +2364,7 @@ func file_vtgate_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_vtgate_proto_rawDesc), len(file_vtgate_proto_rawDesc)), NumEnums: 2, - NumMessages: 23, + NumMessages: 27, NumExtensions: 0, NumServices: 0, }, diff --git a/go/vt/proto/vtgate/vtgate_vtproto.pb.go b/go/vt/proto/vtgate/vtgate_vtproto.pb.go index 874fb3b550a..3a3f05aa593 100644 --- a/go/vt/proto/vtgate/vtgate_vtproto.pb.go +++ b/go/vt/proto/vtgate/vtgate_vtproto.pb.go @@ -179,6 +179,50 @@ func (m *ReadAfterWrite) CloneMessageVT() proto.Message { return m.CloneVT() } +func (m *ExecuteMultiRequest) CloneVT() *ExecuteMultiRequest { + if m == nil { + return (*ExecuteMultiRequest)(nil) + } + r := new(ExecuteMultiRequest) + r.CallerId = m.CallerId.CloneVT() + r.Sql = m.Sql + r.Session = m.Session.CloneVT() + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *ExecuteMultiRequest) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *ExecuteMultiResponse) CloneVT() *ExecuteMultiResponse { + if m == nil { + return (*ExecuteMultiResponse)(nil) + } + r := new(ExecuteMultiResponse) + r.Error = m.Error.CloneVT() + r.Session = m.Session.CloneVT() + if rhs := m.Results; rhs != nil { + tmpContainer := make([]*query.QueryResult, len(rhs)) + for k, v := range rhs { + tmpContainer[k] = v.CloneVT() + } + r.Results = tmpContainer + } + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *ExecuteMultiResponse) CloneMessageVT() proto.Message { + return m.CloneVT() +} + func (m *ExecuteRequest) CloneVT() *ExecuteRequest { if m == nil { return (*ExecuteRequest)(nil) @@ -305,6 +349,45 @@ func (m *StreamExecuteResponse) CloneMessageVT() proto.Message { return m.CloneVT() } +func (m *StreamExecuteMultiRequest) CloneVT() *StreamExecuteMultiRequest { + if m == nil { + return (*StreamExecuteMultiRequest)(nil) + } + r := new(StreamExecuteMultiRequest) + r.CallerId = m.CallerId.CloneVT() + r.Sql = m.Sql + r.Session = m.Session.CloneVT() + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StreamExecuteMultiRequest) CloneMessageVT() proto.Message { + return m.CloneVT() +} + +func (m *StreamExecuteMultiResponse) CloneVT() *StreamExecuteMultiResponse { + if m == nil { + return (*StreamExecuteMultiResponse)(nil) + } + r := new(StreamExecuteMultiResponse) + r.Result = m.Result.CloneVT() + r.MoreResults = m.MoreResults + r.NewResult = m.NewResult + r.Session = m.Session.CloneVT() + if len(m.unknownFields) > 0 { + r.unknownFields = make([]byte, len(m.unknownFields)) + copy(r.unknownFields, m.unknownFields) + } + return r +} + +func (m *StreamExecuteMultiResponse) CloneMessageVT() proto.Message { + return m.CloneVT() +} + func (m *ResolveTransactionRequest) CloneVT() *ResolveTransactionRequest { if m == nil { return (*ResolveTransactionRequest)(nil) @@ -1005,6 +1088,131 @@ func (m *ReadAfterWrite) MarshalToSizedBufferVT(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *ExecuteMultiRequest) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ExecuteMultiRequest) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *ExecuteMultiRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if m.Session != nil { + size, err := m.Session.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x1a + } + if len(m.Sql) > 0 { + i -= len(m.Sql) + copy(dAtA[i:], m.Sql) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Sql))) + i-- + dAtA[i] = 0x12 + } + if m.CallerId != nil { + size, err := m.CallerId.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *ExecuteMultiResponse) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ExecuteMultiResponse) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *ExecuteMultiResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if len(m.Results) > 0 { + for iNdEx := len(m.Results) - 1; iNdEx >= 0; iNdEx-- { + size, err := m.Results[iNdEx].MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x1a + } + } + if m.Session != nil { + size, err := m.Session.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x12 + } + if m.Error != nil { + size, err := m.Error.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func (m *ExecuteRequest) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil @@ -1387,7 +1595,7 @@ func (m *StreamExecuteResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) return len(dAtA) - i, nil } -func (m *ResolveTransactionRequest) MarshalVT() (dAtA []byte, err error) { +func (m *StreamExecuteMultiRequest) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil } @@ -1400,12 +1608,12 @@ func (m *ResolveTransactionRequest) MarshalVT() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *ResolveTransactionRequest) MarshalToVT(dAtA []byte) (int, error) { +func (m *StreamExecuteMultiRequest) MarshalToVT(dAtA []byte) (int, error) { size := m.SizeVT() return m.MarshalToSizedBufferVT(dAtA[:size]) } -func (m *ResolveTransactionRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { +func (m *StreamExecuteMultiRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { return 0, nil } @@ -1417,10 +1625,20 @@ func (m *ResolveTransactionRequest) MarshalToSizedBufferVT(dAtA []byte) (int, er i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } - if len(m.Dtid) > 0 { - i -= len(m.Dtid) - copy(dAtA[i:], m.Dtid) - i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Dtid))) + if m.Session != nil { + size, err := m.Session.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x1a + } + if len(m.Sql) > 0 { + i -= len(m.Sql) + copy(dAtA[i:], m.Sql) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Sql))) i-- dAtA[i] = 0x12 } @@ -1437,7 +1655,7 @@ func (m *ResolveTransactionRequest) MarshalToSizedBufferVT(dAtA []byte) (int, er return len(dAtA) - i, nil } -func (m *ResolveTransactionResponse) MarshalVT() (dAtA []byte, err error) { +func (m *StreamExecuteMultiResponse) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil } @@ -1450,12 +1668,12 @@ func (m *ResolveTransactionResponse) MarshalVT() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *ResolveTransactionResponse) MarshalToVT(dAtA []byte) (int, error) { +func (m *StreamExecuteMultiResponse) MarshalToVT(dAtA []byte) (int, error) { size := m.SizeVT() return m.MarshalToSizedBufferVT(dAtA[:size]) } -func (m *ResolveTransactionResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { +func (m *StreamExecuteMultiResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { return 0, nil } @@ -1467,10 +1685,50 @@ func (m *ResolveTransactionResponse) MarshalToSizedBufferVT(dAtA []byte) (int, e i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } + if m.Session != nil { + size, err := m.Session.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0x22 + } + if m.NewResult { + i-- + if m.NewResult { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x18 + } + if m.MoreResults { + i-- + if m.MoreResults { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x10 + } + if m.Result != nil { + size, err := m.Result.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0xa + } return len(dAtA) - i, nil } -func (m *VStreamFlags) MarshalVT() (dAtA []byte, err error) { +func (m *ResolveTransactionRequest) MarshalVT() (dAtA []byte, err error) { if m == nil { return nil, nil } @@ -1483,12 +1741,12 @@ func (m *VStreamFlags) MarshalVT() (dAtA []byte, err error) { return dAtA[:n], nil } -func (m *VStreamFlags) MarshalToVT(dAtA []byte) (int, error) { +func (m *ResolveTransactionRequest) MarshalToVT(dAtA []byte) (int, error) { size := m.SizeVT() return m.MarshalToSizedBufferVT(dAtA[:size]) } -func (m *VStreamFlags) MarshalToSizedBufferVT(dAtA []byte) (int, error) { +func (m *ResolveTransactionRequest) MarshalToSizedBufferVT(dAtA []byte) (int, error) { if m == nil { return 0, nil } @@ -1500,12 +1758,95 @@ func (m *VStreamFlags) MarshalToSizedBufferVT(dAtA []byte) (int, error) { i -= len(m.unknownFields) copy(dAtA[i:], m.unknownFields) } - if m.IncludeReshardJournalEvents { + if len(m.Dtid) > 0 { + i -= len(m.Dtid) + copy(dAtA[i:], m.Dtid) + i = protohelpers.EncodeVarint(dAtA, i, uint64(len(m.Dtid))) i-- - if m.IncludeReshardJournalEvents { - dAtA[i] = 1 - } else { - dAtA[i] = 0 + dAtA[i] = 0x12 + } + if m.CallerId != nil { + size, err := m.CallerId.MarshalToSizedBufferVT(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = protohelpers.EncodeVarint(dAtA, i, uint64(size)) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + +func (m *ResolveTransactionResponse) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *ResolveTransactionResponse) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *ResolveTransactionResponse) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + return len(dAtA) - i, nil +} + +func (m *VStreamFlags) MarshalVT() (dAtA []byte, err error) { + if m == nil { + return nil, nil + } + size := m.SizeVT() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBufferVT(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *VStreamFlags) MarshalToVT(dAtA []byte) (int, error) { + size := m.SizeVT() + return m.MarshalToSizedBufferVT(dAtA[:size]) +} + +func (m *VStreamFlags) MarshalToSizedBufferVT(dAtA []byte) (int, error) { + if m == nil { + return 0, nil + } + i := len(dAtA) + _ = i + var l int + _ = l + if m.unknownFields != nil { + i -= len(m.unknownFields) + copy(dAtA[i:], m.unknownFields) + } + if m.IncludeReshardJournalEvents { + i-- + if m.IncludeReshardJournalEvents { + dAtA[i] = 1 + } else { + dAtA[i] = 0 } i-- dAtA[i] = 0x40 @@ -2131,6 +2472,52 @@ func (m *ReadAfterWrite) SizeVT() (n int) { return n } +func (m *ExecuteMultiRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.CallerId != nil { + l = m.CallerId.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.Sql) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Session != nil { + l = m.Session.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *ExecuteMultiResponse) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Error != nil { + l = m.Error.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Session != nil { + l = m.Session.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if len(m.Results) > 0 { + for _, e := range m.Results { + l = e.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + } + n += len(m.unknownFields) + return n +} + func (m *ExecuteRequest) SizeVT() (n int) { if m == nil { return 0 @@ -2266,6 +2653,52 @@ func (m *StreamExecuteResponse) SizeVT() (n int) { return n } +func (m *StreamExecuteMultiRequest) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.CallerId != nil { + l = m.CallerId.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + l = len(m.Sql) + if l > 0 { + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.Session != nil { + l = m.Session.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + +func (m *StreamExecuteMultiResponse) SizeVT() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Result != nil { + l = m.Result.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + if m.MoreResults { + n += 2 + } + if m.NewResult { + n += 2 + } + if m.Session != nil { + l = m.Session.SizeVT() + n += 1 + l + protohelpers.SizeOfVarint(uint64(l)) + } + n += len(m.unknownFields) + return n +} + func (m *ResolveTransactionRequest) SizeVT() (n int) { if m == nil { return 0 @@ -4041,7 +4474,7 @@ func (m *ReadAfterWrite) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { +func (m *ExecuteMultiRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4064,10 +4497,10 @@ func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: ExecuteRequest: wiretype end group for non-group") + return fmt.Errorf("proto: ExecuteMultiRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: ExecuteRequest: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: ExecuteMultiRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -4108,9 +4541,9 @@ func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Sql", wireType) } - var msglen int + var stringLen uint64 for shift := uint(0); ; shift += 7 { if shift >= 64 { return protohelpers.ErrIntOverflow @@ -4120,31 +4553,27 @@ func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { } b := dAtA[iNdEx] iNdEx++ - msglen |= int(b&0x7F) << shift + stringLen |= uint64(b&0x7F) << shift if b < 0x80 { break } } - if msglen < 0 { + intStringLen := int(stringLen) + if intStringLen < 0 { return protohelpers.ErrInvalidLength } - postIndex := iNdEx + msglen + postIndex := iNdEx + intStringLen if postIndex < 0 { return protohelpers.ErrInvalidLength } if postIndex > l { return io.ErrUnexpectedEOF } - if m.Session == nil { - m.Session = &Session{} - } - if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { - return err - } + m.Sql = string(dAtA[iNdEx:postIndex]) iNdEx = postIndex case 3: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Query", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4171,33 +4600,13 @@ func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Query == nil { - m.Query = &query.BoundQuery{} + if m.Session == nil { + m.Session = &Session{} } - if err := m.Query.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex - case 8: - if wireType != 0 { - return fmt.Errorf("proto: wrong wireType = %d for field Prepared", wireType) - } - var v int - for shift := uint(0); ; shift += 7 { - if shift >= 64 { - return protohelpers.ErrIntOverflow - } - if iNdEx >= l { - return io.ErrUnexpectedEOF - } - b := dAtA[iNdEx] - iNdEx++ - v |= int(b&0x7F) << shift - if b < 0x80 { - break - } - } - m.Prepared = bool(v != 0) default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) @@ -4220,7 +4629,7 @@ func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { +func (m *ExecuteMultiResponse) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4243,10 +4652,10 @@ func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: ExecuteResponse: wiretype end group for non-group") + return fmt.Errorf("proto: ExecuteMultiResponse: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: ExecuteResponse: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: ExecuteMultiResponse: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -4323,7 +4732,7 @@ func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex case 3: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Result", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Results", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4350,10 +4759,8 @@ func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Result == nil { - m.Result = &query.QueryResult{} - } - if err := m.Result.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + m.Results = append(m.Results, &query.QueryResult{}) + if err := m.Results[len(m.Results)-1].UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -4379,7 +4786,7 @@ func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { +func (m *ExecuteRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4402,10 +4809,10 @@ func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: ExecuteBatchRequest: wiretype end group for non-group") + return fmt.Errorf("proto: ExecuteRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: ExecuteBatchRequest: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: ExecuteRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -4482,7 +4889,7 @@ func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex case 3: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Queries", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Query", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4509,11 +4916,33 @@ func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Queries = append(m.Queries, &query.BoundQuery{}) - if err := m.Queries[len(m.Queries)-1].UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + if m.Query == nil { + m.Query = &query.BoundQuery{} + } + if err := m.Query.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex + case 8: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Prepared", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.Prepared = bool(v != 0) default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) @@ -4536,7 +4965,7 @@ func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { +func (m *ExecuteResponse) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4559,10 +4988,10 @@ func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: ExecuteBatchResponse: wiretype end group for non-group") + return fmt.Errorf("proto: ExecuteResponse: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: ExecuteBatchResponse: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: ExecuteResponse: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -4639,7 +5068,7 @@ func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex case 3: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Results", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Result", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4666,8 +5095,10 @@ func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - m.Results = append(m.Results, &query.ResultWithError{}) - if err := m.Results[len(m.Results)-1].UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + if m.Result == nil { + m.Result = &query.QueryResult{} + } + if err := m.Result.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex @@ -4693,7 +5124,7 @@ func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { +func (m *ExecuteBatchRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4716,10 +5147,10 @@ func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: StreamExecuteRequest: wiretype end group for non-group") + return fmt.Errorf("proto: ExecuteBatchRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: StreamExecuteRequest: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: ExecuteBatchRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: @@ -4760,7 +5191,7 @@ func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { iNdEx = postIndex case 2: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Query", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4787,14 +5218,135 @@ func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Query == nil { - m.Query = &query.BoundQuery{} + if m.Session == nil { + m.Session = &Session{} } - if err := m.Query.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex - case 6: + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Queries", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Queries = append(m.Queries, &query.BoundQuery{}) + if err := m.Queries[len(m.Queries)-1].UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *ExecuteBatchResponse) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: ExecuteBatchResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: ExecuteBatchResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Error == nil { + m.Error = &vtrpc.RPCError{} + } + if err := m.Error.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) } @@ -4830,6 +5382,40 @@ func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { return err } iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Results", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Results = append(m.Results, &query.ResultWithError{}) + if err := m.Results[len(m.Results)-1].UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex default: iNdEx = preIndex skippy, err := protohelpers.Skip(dAtA[iNdEx:]) @@ -4852,7 +5438,7 @@ func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { } return nil } -func (m *StreamExecuteResponse) UnmarshalVT(dAtA []byte) error { +func (m *StreamExecuteRequest) UnmarshalVT(dAtA []byte) error { l := len(dAtA) iNdEx := 0 for iNdEx < l { @@ -4875,15 +5461,15 @@ func (m *StreamExecuteResponse) UnmarshalVT(dAtA []byte) error { fieldNum := int32(wire >> 3) wireType := int(wire & 0x7) if wireType == 4 { - return fmt.Errorf("proto: StreamExecuteResponse: wiretype end group for non-group") + return fmt.Errorf("proto: StreamExecuteRequest: wiretype end group for non-group") } if fieldNum <= 0 { - return fmt.Errorf("proto: StreamExecuteResponse: illegal tag %d (wire type %d)", fieldNum, wire) + return fmt.Errorf("proto: StreamExecuteRequest: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { case 1: if wireType != 2 { - return fmt.Errorf("proto: wrong wireType = %d for field Result", wireType) + return fmt.Errorf("proto: wrong wireType = %d for field CallerId", wireType) } var msglen int for shift := uint(0); ; shift += 7 { @@ -4910,14 +5496,491 @@ func (m *StreamExecuteResponse) UnmarshalVT(dAtA []byte) error { if postIndex > l { return io.ErrUnexpectedEOF } - if m.Result == nil { - m.Result = &query.QueryResult{} + if m.CallerId == nil { + m.CallerId = &vtrpc.CallerID{} } - if err := m.Result.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + if err := m.CallerId.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { return err } iNdEx = postIndex case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Query", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Query == nil { + m.Query = &query.BoundQuery{} + } + if err := m.Query.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 6: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Session == nil { + m.Session = &Session{} + } + if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *StreamExecuteResponse) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StreamExecuteResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StreamExecuteResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Result", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Result == nil { + m.Result = &query.QueryResult{} + } + if err := m.Result.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Session == nil { + m.Session = &Session{} + } + if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *StreamExecuteMultiRequest) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StreamExecuteMultiRequest: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StreamExecuteMultiRequest: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field CallerId", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.CallerId == nil { + m.CallerId = &vtrpc.CallerID{} + } + if err := m.CallerId.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Sql", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.Sql = string(dAtA[iNdEx:postIndex]) + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Session == nil { + m.Session = &Session{} + } + if err := m.Session.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := protohelpers.Skip(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return protohelpers.ErrInvalidLength + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.unknownFields = append(m.unknownFields, dAtA[iNdEx:iNdEx+skippy]...) + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} +func (m *StreamExecuteMultiResponse) UnmarshalVT(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: StreamExecuteMultiResponse: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: StreamExecuteMultiResponse: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Result", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return protohelpers.ErrInvalidLength + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return protohelpers.ErrInvalidLength + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Result == nil { + m.Result = &query.ResultWithError{} + } + if err := m.Result.UnmarshalVT(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex + case 2: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field MoreResults", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.MoreResults = bool(v != 0) + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field NewResult", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return protohelpers.ErrIntOverflow + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + m.NewResult = bool(v != 0) + case 4: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field Session", wireType) } diff --git a/go/vt/proto/vtgateservice/vtgateservice.pb.go b/go/vt/proto/vtgateservice/vtgateservice.pb.go index 3df5bc80722..2cf68361057 100644 --- a/go/vt/proto/vtgateservice/vtgateservice.pb.go +++ b/go/vt/proto/vtgateservice/vtgateservice.pb.go @@ -45,70 +45,88 @@ var file_vtgateservice_proto_rawDesc = string([]byte{ 0x0a, 0x13, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0d, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x1a, 0x0c, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x32, 0xb0, 0x03, 0x0a, 0x06, 0x56, 0x69, 0x74, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, + 0x74, 0x6f, 0x32, 0xde, 0x04, 0x0a, 0x06, 0x56, 0x69, 0x74, 0x65, 0x73, 0x73, 0x12, 0x3c, 0x0a, 0x07, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x12, 0x16, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x45, - 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x1b, 0x2e, 0x76, 0x74, - 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x42, 0x61, 0x74, 0x63, - 0x68, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, + 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x12, 0x1b, 0x2e, 0x76, 0x74, + 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, + 0x69, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, + 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x45, 0x78, 0x65, 0x63, + 0x75, 0x74, 0x65, 0x42, 0x61, 0x74, 0x63, 0x68, 0x12, 0x1b, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x50, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, - 0x61, 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x12, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, - 0x74, 0x65, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, - 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3e, 0x0a, 0x07, 0x56, 0x53, - 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x16, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x56, - 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, - 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x56, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3c, 0x0a, 0x07, 0x50, 0x72, - 0x65, 0x70, 0x61, 0x72, 0x65, 0x12, 0x16, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x50, - 0x72, 0x65, 0x70, 0x61, 0x72, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, - 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x50, 0x72, 0x65, 0x70, 0x61, 0x72, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x43, 0x6c, 0x6f, 0x73, - 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, - 0x65, 0x2e, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x43, - 0x6c, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x42, 0x0a, 0x14, 0x69, 0x6f, 0x2e, 0x76, 0x69, 0x74, 0x65, - 0x73, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x5a, 0x2a, 0x76, - 0x69, 0x74, 0x65, 0x73, 0x73, 0x2e, 0x69, 0x6f, 0x2f, 0x76, 0x69, 0x74, 0x65, 0x73, 0x73, 0x2f, - 0x67, 0x6f, 0x2f, 0x76, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x76, 0x74, 0x67, 0x61, - 0x74, 0x65, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x45, + 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x42, 0x61, 0x74, 0x63, 0x68, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x50, 0x0a, 0x0d, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x45, + 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x12, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, + 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x5f, 0x0a, 0x12, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x12, 0x21, 0x2e, + 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x45, 0x78, 0x65, + 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x22, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x45, 0x78, 0x65, 0x63, 0x75, 0x74, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3e, 0x0a, 0x07, 0x56, 0x53, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x12, 0x16, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x56, 0x53, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x76, 0x74, + 0x67, 0x61, 0x74, 0x65, 0x2e, 0x56, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3c, 0x0a, 0x07, 0x50, 0x72, 0x65, 0x70, + 0x61, 0x72, 0x65, 0x12, 0x16, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x50, 0x72, 0x65, + 0x70, 0x61, 0x72, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x17, 0x2e, 0x76, 0x74, + 0x67, 0x61, 0x74, 0x65, 0x2e, 0x50, 0x72, 0x65, 0x70, 0x61, 0x72, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x53, + 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, + 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, 0x2e, 0x43, 0x6c, 0x6f, + 0x73, 0x65, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x42, 0x42, 0x0a, 0x14, 0x69, 0x6f, 0x2e, 0x76, 0x69, 0x74, 0x65, 0x73, 0x73, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x67, 0x72, 0x70, 0x63, 0x5a, 0x2a, 0x76, 0x69, 0x74, + 0x65, 0x73, 0x73, 0x2e, 0x69, 0x6f, 0x2f, 0x76, 0x69, 0x74, 0x65, 0x73, 0x73, 0x2f, 0x67, 0x6f, + 0x2f, 0x76, 0x74, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x76, 0x74, 0x67, 0x61, 0x74, 0x65, + 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, }) var file_vtgateservice_proto_goTypes = []any{ - (*vtgate.ExecuteRequest)(nil), // 0: vtgate.ExecuteRequest - (*vtgate.ExecuteBatchRequest)(nil), // 1: vtgate.ExecuteBatchRequest - (*vtgate.StreamExecuteRequest)(nil), // 2: vtgate.StreamExecuteRequest - (*vtgate.VStreamRequest)(nil), // 3: vtgate.VStreamRequest - (*vtgate.PrepareRequest)(nil), // 4: vtgate.PrepareRequest - (*vtgate.CloseSessionRequest)(nil), // 5: vtgate.CloseSessionRequest - (*vtgate.ExecuteResponse)(nil), // 6: vtgate.ExecuteResponse - (*vtgate.ExecuteBatchResponse)(nil), // 7: vtgate.ExecuteBatchResponse - (*vtgate.StreamExecuteResponse)(nil), // 8: vtgate.StreamExecuteResponse - (*vtgate.VStreamResponse)(nil), // 9: vtgate.VStreamResponse - (*vtgate.PrepareResponse)(nil), // 10: vtgate.PrepareResponse - (*vtgate.CloseSessionResponse)(nil), // 11: vtgate.CloseSessionResponse + (*vtgate.ExecuteRequest)(nil), // 0: vtgate.ExecuteRequest + (*vtgate.ExecuteMultiRequest)(nil), // 1: vtgate.ExecuteMultiRequest + (*vtgate.ExecuteBatchRequest)(nil), // 2: vtgate.ExecuteBatchRequest + (*vtgate.StreamExecuteRequest)(nil), // 3: vtgate.StreamExecuteRequest + (*vtgate.StreamExecuteMultiRequest)(nil), // 4: vtgate.StreamExecuteMultiRequest + (*vtgate.VStreamRequest)(nil), // 5: vtgate.VStreamRequest + (*vtgate.PrepareRequest)(nil), // 6: vtgate.PrepareRequest + (*vtgate.CloseSessionRequest)(nil), // 7: vtgate.CloseSessionRequest + (*vtgate.ExecuteResponse)(nil), // 8: vtgate.ExecuteResponse + (*vtgate.ExecuteMultiResponse)(nil), // 9: vtgate.ExecuteMultiResponse + (*vtgate.ExecuteBatchResponse)(nil), // 10: vtgate.ExecuteBatchResponse + (*vtgate.StreamExecuteResponse)(nil), // 11: vtgate.StreamExecuteResponse + (*vtgate.StreamExecuteMultiResponse)(nil), // 12: vtgate.StreamExecuteMultiResponse + (*vtgate.VStreamResponse)(nil), // 13: vtgate.VStreamResponse + (*vtgate.PrepareResponse)(nil), // 14: vtgate.PrepareResponse + (*vtgate.CloseSessionResponse)(nil), // 15: vtgate.CloseSessionResponse } var file_vtgateservice_proto_depIdxs = []int32{ 0, // 0: vtgateservice.Vitess.Execute:input_type -> vtgate.ExecuteRequest - 1, // 1: vtgateservice.Vitess.ExecuteBatch:input_type -> vtgate.ExecuteBatchRequest - 2, // 2: vtgateservice.Vitess.StreamExecute:input_type -> vtgate.StreamExecuteRequest - 3, // 3: vtgateservice.Vitess.VStream:input_type -> vtgate.VStreamRequest - 4, // 4: vtgateservice.Vitess.Prepare:input_type -> vtgate.PrepareRequest - 5, // 5: vtgateservice.Vitess.CloseSession:input_type -> vtgate.CloseSessionRequest - 6, // 6: vtgateservice.Vitess.Execute:output_type -> vtgate.ExecuteResponse - 7, // 7: vtgateservice.Vitess.ExecuteBatch:output_type -> vtgate.ExecuteBatchResponse - 8, // 8: vtgateservice.Vitess.StreamExecute:output_type -> vtgate.StreamExecuteResponse - 9, // 9: vtgateservice.Vitess.VStream:output_type -> vtgate.VStreamResponse - 10, // 10: vtgateservice.Vitess.Prepare:output_type -> vtgate.PrepareResponse - 11, // 11: vtgateservice.Vitess.CloseSession:output_type -> vtgate.CloseSessionResponse - 6, // [6:12] is the sub-list for method output_type - 0, // [0:6] is the sub-list for method input_type + 1, // 1: vtgateservice.Vitess.ExecuteMulti:input_type -> vtgate.ExecuteMultiRequest + 2, // 2: vtgateservice.Vitess.ExecuteBatch:input_type -> vtgate.ExecuteBatchRequest + 3, // 3: vtgateservice.Vitess.StreamExecute:input_type -> vtgate.StreamExecuteRequest + 4, // 4: vtgateservice.Vitess.StreamExecuteMulti:input_type -> vtgate.StreamExecuteMultiRequest + 5, // 5: vtgateservice.Vitess.VStream:input_type -> vtgate.VStreamRequest + 6, // 6: vtgateservice.Vitess.Prepare:input_type -> vtgate.PrepareRequest + 7, // 7: vtgateservice.Vitess.CloseSession:input_type -> vtgate.CloseSessionRequest + 8, // 8: vtgateservice.Vitess.Execute:output_type -> vtgate.ExecuteResponse + 9, // 9: vtgateservice.Vitess.ExecuteMulti:output_type -> vtgate.ExecuteMultiResponse + 10, // 10: vtgateservice.Vitess.ExecuteBatch:output_type -> vtgate.ExecuteBatchResponse + 11, // 11: vtgateservice.Vitess.StreamExecute:output_type -> vtgate.StreamExecuteResponse + 12, // 12: vtgateservice.Vitess.StreamExecuteMulti:output_type -> vtgate.StreamExecuteMultiResponse + 13, // 13: vtgateservice.Vitess.VStream:output_type -> vtgate.VStreamResponse + 14, // 14: vtgateservice.Vitess.Prepare:output_type -> vtgate.PrepareResponse + 15, // 15: vtgateservice.Vitess.CloseSession:output_type -> vtgate.CloseSessionResponse + 8, // [8:16] is the sub-list for method output_type + 0, // [0:8] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/go/vt/proto/vtgateservice/vtgateservice_grpc.pb.go b/go/vt/proto/vtgateservice/vtgateservice_grpc.pb.go index 80042781649..4cc67bf42c6 100644 --- a/go/vt/proto/vtgateservice/vtgateservice_grpc.pb.go +++ b/go/vt/proto/vtgateservice/vtgateservice_grpc.pb.go @@ -28,6 +28,8 @@ type VitessClient interface { // information in conjunction with the vindexes to route the query. // API group: v3 Execute(ctx context.Context, in *vtgate.ExecuteRequest, opts ...grpc.CallOption) (*vtgate.ExecuteResponse, error) + // ExecuteMulti executes multiple queries on the right shards. + ExecuteMulti(ctx context.Context, in *vtgate.ExecuteMultiRequest, opts ...grpc.CallOption) (*vtgate.ExecuteMultiResponse, error) // ExecuteBatch tries to route the list of queries on the right shards. // It depends on the query and bind variables to provide enough // information in conjunction with the vindexes to route the query. @@ -39,6 +41,8 @@ type VitessClient interface { // Use this method if the query returns a large number of rows. // API group: v3 StreamExecute(ctx context.Context, in *vtgate.StreamExecuteRequest, opts ...grpc.CallOption) (Vitess_StreamExecuteClient, error) + // StreamExecuteMulti executes multiple streaming queries. + StreamExecuteMulti(ctx context.Context, in *vtgate.StreamExecuteMultiRequest, opts ...grpc.CallOption) (Vitess_StreamExecuteMultiClient, error) // VStream streams binlog events from the requested sources. VStream(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (Vitess_VStreamClient, error) // Prepare is used by the MySQL server plugin as part of supporting prepared statements. @@ -66,6 +70,15 @@ func (c *vitessClient) Execute(ctx context.Context, in *vtgate.ExecuteRequest, o return out, nil } +func (c *vitessClient) ExecuteMulti(ctx context.Context, in *vtgate.ExecuteMultiRequest, opts ...grpc.CallOption) (*vtgate.ExecuteMultiResponse, error) { + out := new(vtgate.ExecuteMultiResponse) + err := c.cc.Invoke(ctx, "/vtgateservice.Vitess/ExecuteMulti", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *vitessClient) ExecuteBatch(ctx context.Context, in *vtgate.ExecuteBatchRequest, opts ...grpc.CallOption) (*vtgate.ExecuteBatchResponse, error) { out := new(vtgate.ExecuteBatchResponse) err := c.cc.Invoke(ctx, "/vtgateservice.Vitess/ExecuteBatch", in, out, opts...) @@ -107,8 +120,40 @@ func (x *vitessStreamExecuteClient) Recv() (*vtgate.StreamExecuteResponse, error return m, nil } +func (c *vitessClient) StreamExecuteMulti(ctx context.Context, in *vtgate.StreamExecuteMultiRequest, opts ...grpc.CallOption) (Vitess_StreamExecuteMultiClient, error) { + stream, err := c.cc.NewStream(ctx, &Vitess_ServiceDesc.Streams[1], "/vtgateservice.Vitess/StreamExecuteMulti", opts...) + if err != nil { + return nil, err + } + x := &vitessStreamExecuteMultiClient{stream} + if err := x.ClientStream.SendMsg(in); err != nil { + return nil, err + } + if err := x.ClientStream.CloseSend(); err != nil { + return nil, err + } + return x, nil +} + +type Vitess_StreamExecuteMultiClient interface { + Recv() (*vtgate.StreamExecuteMultiResponse, error) + grpc.ClientStream +} + +type vitessStreamExecuteMultiClient struct { + grpc.ClientStream +} + +func (x *vitessStreamExecuteMultiClient) Recv() (*vtgate.StreamExecuteMultiResponse, error) { + m := new(vtgate.StreamExecuteMultiResponse) + if err := x.ClientStream.RecvMsg(m); err != nil { + return nil, err + } + return m, nil +} + func (c *vitessClient) VStream(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (Vitess_VStreamClient, error) { - stream, err := c.cc.NewStream(ctx, &Vitess_ServiceDesc.Streams[1], "/vtgateservice.Vitess/VStream", opts...) + stream, err := c.cc.NewStream(ctx, &Vitess_ServiceDesc.Streams[2], "/vtgateservice.Vitess/VStream", opts...) if err != nil { return nil, err } @@ -166,6 +211,8 @@ type VitessServer interface { // information in conjunction with the vindexes to route the query. // API group: v3 Execute(context.Context, *vtgate.ExecuteRequest) (*vtgate.ExecuteResponse, error) + // ExecuteMulti executes multiple queries on the right shards. + ExecuteMulti(context.Context, *vtgate.ExecuteMultiRequest) (*vtgate.ExecuteMultiResponse, error) // ExecuteBatch tries to route the list of queries on the right shards. // It depends on the query and bind variables to provide enough // information in conjunction with the vindexes to route the query. @@ -177,6 +224,8 @@ type VitessServer interface { // Use this method if the query returns a large number of rows. // API group: v3 StreamExecute(*vtgate.StreamExecuteRequest, Vitess_StreamExecuteServer) error + // StreamExecuteMulti executes multiple streaming queries. + StreamExecuteMulti(*vtgate.StreamExecuteMultiRequest, Vitess_StreamExecuteMultiServer) error // VStream streams binlog events from the requested sources. VStream(*vtgate.VStreamRequest, Vitess_VStreamServer) error // Prepare is used by the MySQL server plugin as part of supporting prepared statements. @@ -195,12 +244,18 @@ type UnimplementedVitessServer struct { func (UnimplementedVitessServer) Execute(context.Context, *vtgate.ExecuteRequest) (*vtgate.ExecuteResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Execute not implemented") } +func (UnimplementedVitessServer) ExecuteMulti(context.Context, *vtgate.ExecuteMultiRequest) (*vtgate.ExecuteMultiResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ExecuteMulti not implemented") +} func (UnimplementedVitessServer) ExecuteBatch(context.Context, *vtgate.ExecuteBatchRequest) (*vtgate.ExecuteBatchResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method ExecuteBatch not implemented") } func (UnimplementedVitessServer) StreamExecute(*vtgate.StreamExecuteRequest, Vitess_StreamExecuteServer) error { return status.Errorf(codes.Unimplemented, "method StreamExecute not implemented") } +func (UnimplementedVitessServer) StreamExecuteMulti(*vtgate.StreamExecuteMultiRequest, Vitess_StreamExecuteMultiServer) error { + return status.Errorf(codes.Unimplemented, "method StreamExecuteMulti not implemented") +} func (UnimplementedVitessServer) VStream(*vtgate.VStreamRequest, Vitess_VStreamServer) error { return status.Errorf(codes.Unimplemented, "method VStream not implemented") } @@ -241,6 +296,24 @@ func _Vitess_Execute_Handler(srv interface{}, ctx context.Context, dec func(inte return interceptor(ctx, in, info, handler) } +func _Vitess_ExecuteMulti_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(vtgate.ExecuteMultiRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(VitessServer).ExecuteMulti(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/vtgateservice.Vitess/ExecuteMulti", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(VitessServer).ExecuteMulti(ctx, req.(*vtgate.ExecuteMultiRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _Vitess_ExecuteBatch_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(vtgate.ExecuteBatchRequest) if err := dec(in); err != nil { @@ -280,6 +353,27 @@ func (x *vitessStreamExecuteServer) Send(m *vtgate.StreamExecuteResponse) error return x.ServerStream.SendMsg(m) } +func _Vitess_StreamExecuteMulti_Handler(srv interface{}, stream grpc.ServerStream) error { + m := new(vtgate.StreamExecuteMultiRequest) + if err := stream.RecvMsg(m); err != nil { + return err + } + return srv.(VitessServer).StreamExecuteMulti(m, &vitessStreamExecuteMultiServer{stream}) +} + +type Vitess_StreamExecuteMultiServer interface { + Send(*vtgate.StreamExecuteMultiResponse) error + grpc.ServerStream +} + +type vitessStreamExecuteMultiServer struct { + grpc.ServerStream +} + +func (x *vitessStreamExecuteMultiServer) Send(m *vtgate.StreamExecuteMultiResponse) error { + return x.ServerStream.SendMsg(m) +} + func _Vitess_VStream_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(vtgate.VStreamRequest) if err := stream.RecvMsg(m); err != nil { @@ -348,6 +442,10 @@ var Vitess_ServiceDesc = grpc.ServiceDesc{ MethodName: "Execute", Handler: _Vitess_Execute_Handler, }, + { + MethodName: "ExecuteMulti", + Handler: _Vitess_ExecuteMulti_Handler, + }, { MethodName: "ExecuteBatch", Handler: _Vitess_ExecuteBatch_Handler, @@ -367,6 +465,11 @@ var Vitess_ServiceDesc = grpc.ServiceDesc{ Handler: _Vitess_StreamExecute_Handler, ServerStreams: true, }, + { + StreamName: "StreamExecuteMulti", + Handler: _Vitess_StreamExecuteMulti_Handler, + ServerStreams: true, + }, { StreamName: "VStream", Handler: _Vitess_VStream_Handler, diff --git a/go/vt/vitessdriver/fakeserver_test.go b/go/vt/vitessdriver/fakeserver_test.go index a2b43caefde..f914b280a1b 100644 --- a/go/vt/vitessdriver/fakeserver_test.go +++ b/go/vt/vitessdriver/fakeserver_test.go @@ -28,6 +28,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/vtgateservice" ) @@ -177,6 +178,43 @@ func (f *fakeVTGateService) VStream(ctx context.Context, tabletType topodatapb.T return nil } +// ExecuteMulti is part of the VTGateService interface +func (f *fakeVTGateService) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, nil, err + } + var result *sqltypes.Result + for _, query := range queries { + session, result, err = f.Execute(ctx, mysqlCtx, session, query, nil, false) + if err != nil { + return session, qrs, err + } + qrs = append(qrs, result) + } + return session, qrs, nil +} + +// StreamExecuteMulti is part of the VTGateService interface +func (f *fakeVTGateService) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, err + } + for idx, query := range queries { + firstPacket := true + session, err = f.StreamExecute(ctx, mysqlCtx, session, query, nil, func(result *sqltypes.Result) error { + err = callback(sqltypes.QueryResponse{QueryResult: result}, idx < len(queries)-1, firstPacket) + firstPacket = false + return err + }) + if err != nil { + return session, err + } + } + return session, nil +} + // HandlePanic is part of the VTGateService interface func (f *fakeVTGateService) HandlePanic(err *error) { if x := recover(); x != nil { diff --git a/go/vt/vtgate/fakerpcvtgateconn/conn.go b/go/vt/vtgate/fakerpcvtgateconn/conn.go index 894ac5e2193..259b68e06ba 100644 --- a/go/vt/vtgate/fakerpcvtgateconn/conn.go +++ b/go/vt/vtgate/fakerpcvtgateconn/conn.go @@ -115,6 +115,16 @@ func (conn *FakeVTGateConn) ExecuteBatch(ctx context.Context, session *vtgatepb. panic("not implemented") } +// ExecuteMulti please see vtgateconn.Impl.ExecuteBatch +func (conn *FakeVTGateConn) ExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string) (*vtgatepb.Session, []*sqltypes.Result, error) { + panic("not implemented") +} + +// StreamExecuteMulti please see vtgateconn.Impl.ExecuteBatch. +func (conn *FakeVTGateConn) StreamExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string, processResponse func(response *vtgatepb.StreamExecuteMultiResponse)) (sqltypes.MultiResultStream, error) { + panic("not implemented") +} + // StreamExecute please see vtgateconn.Impl.StreamExecute func (conn *FakeVTGateConn) StreamExecute(ctx context.Context, session *vtgatepb.Session, sql string, bindVars map[string]*querypb.BindVariable, _ func(response *vtgatepb.StreamExecuteResponse)) (sqltypes.ResultStream, error) { response, ok := conn.execMap[sql] diff --git a/go/vt/vtgate/grpcvtgateconn/conn.go b/go/vt/vtgate/grpcvtgateconn/conn.go index f37f61fd9cc..86967de9662 100644 --- a/go/vt/vtgate/grpcvtgateconn/conn.go +++ b/go/vt/vtgate/grpcvtgateconn/conn.go @@ -207,6 +207,69 @@ func (conn *vtgateConn) StreamExecute(ctx context.Context, session *vtgatepb.Ses }, nil } +// ExecuteMulti executes multiple non-streaming queries. +func (conn *vtgateConn) ExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + request := &vtgatepb.ExecuteMultiRequest{ + CallerId: callerid.EffectiveCallerIDFromContext(ctx), + Session: session, + Sql: sqlString, + } + response, err := conn.c.ExecuteMulti(ctx, request) + if err != nil { + return session, nil, vterrors.FromGRPC(err) + } + return response.Session, sqltypes.Proto3ToResults(response.Results), vterrors.FromVTRPC(response.Error) +} + +type streamExecuteMultiAdapter struct { + recv func() (*querypb.QueryResult, bool, error) + fields []*querypb.Field +} + +func (a *streamExecuteMultiAdapter) Recv() (*sqltypes.Result, bool, error) { + var qr *querypb.QueryResult + var err error + var newResult bool + for { + qr, newResult, err = a.recv() + if qr != nil || err != nil { + break + } + // we reach here, only when it is the last packet. + // as in the last packet we receive the session and there is no result + } + if err != nil { + return nil, newResult, err + } + if qr != nil && qr.Fields != nil { + a.fields = qr.Fields + } + return sqltypes.CustomProto3ToResult(a.fields, qr), newResult, nil +} + +// StreamExecuteMulti executes multiple streaming queries. +func (conn *vtgateConn) StreamExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string, processResponse func(response *vtgatepb.StreamExecuteMultiResponse)) (sqltypes.MultiResultStream, error) { + req := &vtgatepb.StreamExecuteMultiRequest{ + CallerId: callerid.EffectiveCallerIDFromContext(ctx), + Sql: sqlString, + Session: session, + } + stream, err := conn.c.StreamExecuteMulti(ctx, req) + if err != nil { + return nil, vterrors.FromGRPC(err) + } + return &streamExecuteMultiAdapter{ + recv: func() (*querypb.QueryResult, bool, error) { + ser, err := stream.Recv() + if err != nil { + return nil, false, vterrors.FromGRPC(err) + } + processResponse(ser) + return ser.Result.GetResult(), ser.NewResult, vterrors.FromVTRPC(ser.Result.GetError()) + }, + }, nil +} + func (conn *vtgateConn) Prepare(ctx context.Context, session *vtgatepb.Session, query string) (*vtgatepb.Session, []*querypb.Field, uint16, error) { request := &vtgatepb.PrepareRequest{ CallerId: callerid.EffectiveCallerIDFromContext(ctx), diff --git a/go/vt/vtgate/grpcvtgateconn/suite_test.go b/go/vt/vtgate/grpcvtgateconn/suite_test.go index 4ae478a78c0..064d11021cc 100644 --- a/go/vt/vtgate/grpcvtgateconn/suite_test.go +++ b/go/vt/vtgate/grpcvtgateconn/suite_test.go @@ -39,6 +39,7 @@ import ( topodatapb "vitess.io/vitess/go/vt/proto/topodata" vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/vtgateconn" "vitess.io/vitess/go/vt/vtgate/vtgateservice" @@ -106,6 +107,7 @@ func (f *fakeVTGateService) Execute( panic(fmt.Errorf("test forced panic")) } f.checkCallerID(ctx, "Execute") + sql = strings.TrimSpace(sql) execCase, ok := execMap[sql] if !ok { return session, nil, fmt.Errorf("no match for: %s", sql) @@ -163,6 +165,7 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, mysqlCtx vtgatese if f.panics { panic(fmt.Errorf("test forced panic")) } + sql = strings.TrimSpace(sql) execCase, ok := execMap[sql] if !ok { return session, fmt.Errorf("no match for: %s", sql) @@ -200,9 +203,49 @@ func (f *fakeVTGateService) StreamExecute(ctx context.Context, mysqlCtx vtgatese } } } + if execCase.outSession == nil { + return session, nil + } return execCase.outSession, nil } +// ExecuteMulti is part of the VTGateService interface +func (f *fakeVTGateService) ExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, nil, err + } + var result *sqltypes.Result + for _, query := range queries { + session, result, err = f.Execute(ctx, mysqlCtx, session, query, nil, false) + if err != nil { + return session, qrs, err + } + qrs = append(qrs, result) + } + return session, qrs, nil +} + +// StreamExecuteMulti is part of the VTGateService interface +func (f *fakeVTGateService) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + queries, err := sqlparser.NewTestParser().SplitStatementToPieces(sqlString) + if err != nil { + return session, err + } + for idx, query := range queries { + firstPacket := true + session, err = f.StreamExecute(ctx, mysqlCtx, session, query, nil, func(result *sqltypes.Result) error { + err = callback(sqltypes.QueryResponse{QueryResult: result}, idx < len(queries)-1, firstPacket) + firstPacket = false + return err + }) + if err != nil { + return session, err + } + } + return session, nil +} + // Prepare is part of the VTGateService interface func (f *fakeVTGateService) Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) { if f.hasError { @@ -279,15 +322,19 @@ func RunTests(t *testing.T, impl vtgateconn.Impl, fakeServer vtgateservice.VTGat fs := fakeServer.(*fakeVTGateService) testExecute(t, session) + testExecuteMulti(t, session) testStreamExecute(t, session) + testStreamExecuteMulti(t, session) testExecuteBatch(t, session) testPrepare(t, session) // force a panic at every call, then test that works fs.panics = true testExecutePanic(t, session) + testExecuteMultiPanic(t, session) testExecuteBatchPanic(t, session) testStreamExecutePanic(t, session) + testStreamExecuteMultiPanic(t, session) testPreparePanic(t, session) fs.panics = false } @@ -360,6 +407,27 @@ func testExecute(t *testing.T, session *vtgateconn.VTGateSession) { } } +func testExecuteMulti(t *testing.T, session *vtgateconn.VTGateSession) { + ctx := newContext() + execCase := execMap["request1"] + multiQuery := fmt.Sprintf("%s; %s", execCase.execQuery.SQL, execCase.execQuery.SQL) + qrs, err := session.ExecuteMulti(ctx, multiQuery) + require.NoError(t, err) + require.Len(t, qrs, 2) + require.True(t, qrs[0].Equal(execCase.result)) + require.True(t, qrs[1].Equal(execCase.result)) + + qrs, err = session.ExecuteMulti(ctx, "none; request1") + require.ErrorContains(t, err, "no match for: none") + require.Nil(t, qrs) + + // Check that we get a single result if we have an error in the second query + qrs, err = session.ExecuteMulti(ctx, "request1; none") + require.ErrorContains(t, err, "no match for: none") + require.Len(t, qrs, 1) + require.True(t, qrs[0].Equal(execCase.result)) +} + func testExecuteError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { ctx := newContext() execCase := execMap["errorRequst"] @@ -375,6 +443,14 @@ func testExecutePanic(t *testing.T, session *vtgateconn.VTGateSession) { expectPanic(t, err) } +func testExecuteMultiPanic(t *testing.T, session *vtgateconn.VTGateSession) { + ctx := newContext() + execCase := execMap["request1"] + multiQuery := fmt.Sprintf("%s; %s", execCase.execQuery.SQL, execCase.execQuery.SQL) + _, err := session.ExecuteMulti(ctx, multiQuery) + expectPanic(t, err) +} + func testExecuteBatch(t *testing.T, session *vtgateconn.VTGateSession) { ctx := newContext() execCase := execMap["request1"] @@ -448,6 +524,73 @@ func testStreamExecute(t *testing.T, session *vtgateconn.VTGateSession) { } } +func testStreamExecuteMulti(t *testing.T, session *vtgateconn.VTGateSession) { + ctx := newContext() + execCase := execMap["request1"] + multiQuery := fmt.Sprintf("%s; %s", execCase.execQuery.SQL, execCase.execQuery.SQL) + stream, err := session.StreamExecuteMulti(ctx, multiQuery) + require.NoError(t, err) + var qr *sqltypes.Result + var qrs []*sqltypes.Result + for { + packet, newRes, err := stream.Recv() + if err != nil { + if err != io.EOF { + t.Error(err) + } + break + } + if newRes { + if qr != nil { + qrs = append(qrs, qr) + } + qr = &sqltypes.Result{} + } + if len(packet.Fields) != 0 { + qr.Fields = packet.Fields + } + if len(packet.Rows) != 0 { + qr.Rows = append(qr.Rows, packet.Rows...) + } + } + if qr != nil { + qrs = append(qrs, qr) + } + wantResult := execCase.result.Copy() + wantResult.RowsAffected = 0 + wantResult.InsertID = 0 + wantResult.InsertIDChanged = false + require.NoError(t, err) + require.Len(t, qrs, 2) + require.True(t, qrs[0].Equal(wantResult)) + require.True(t, qrs[1].Equal(wantResult)) + + stream, err = session.StreamExecuteMulti(ctx, "none; request1") + require.NoError(t, err) + qr, _, err = stream.Recv() + require.ErrorContains(t, err, "no match for: none") + require.Nil(t, qr) + + stream, err = session.StreamExecuteMulti(ctx, "request1; none") + require.NoError(t, err) + var packet *sqltypes.Result + qr = &sqltypes.Result{} + for { + packet, _, err = stream.Recv() + if err != nil { + break + } + if len(packet.Fields) != 0 { + qr.Fields = packet.Fields + } + if len(packet.Rows) != 0 { + qr.Rows = append(qr.Rows, packet.Rows...) + } + } + require.ErrorContains(t, err, "no match for: none") + require.True(t, qr.Equal(wantResult)) +} + func testStreamExecuteError(t *testing.T, session *vtgateconn.VTGateSession, fake *fakeVTGateService) { ctx := newContext() execCase := execMap["request1"] @@ -487,6 +630,17 @@ func testStreamExecutePanic(t *testing.T, session *vtgateconn.VTGateSession) { expectPanic(t, err) } +func testStreamExecuteMultiPanic(t *testing.T, session *vtgateconn.VTGateSession) { + ctx := newContext() + execCase := execMap["request1"] + multiQuery := fmt.Sprintf("%s; %s", execCase.execQuery.SQL, execCase.execQuery.SQL) + stream, err := session.StreamExecuteMulti(ctx, multiQuery) + require.NoError(t, err) + _, _, err = stream.Recv() + require.Error(t, err) + expectPanic(t, err) +} + func testPrepare(t *testing.T, session *vtgateconn.VTGateSession) { ctx := newContext() execCase := execMap["request1"] diff --git a/go/vt/vtgate/grpcvtgateservice/server.go b/go/vt/vtgate/grpcvtgateservice/server.go index 46c0bc8f242..04790328c06 100644 --- a/go/vt/vtgate/grpcvtgateservice/server.go +++ b/go/vt/vtgate/grpcvtgateservice/server.go @@ -154,6 +154,62 @@ func (vtg *VTGate) Execute(ctx context.Context, request *vtgatepb.ExecuteRequest }, nil } +// ExecuteMulti is the RPC version of vtgateservice.VTGateService method +func (vtg *VTGate) ExecuteMulti(ctx context.Context, request *vtgatepb.ExecuteMultiRequest) (response *vtgatepb.ExecuteMultiResponse, err error) { + defer vtg.server.HandlePanic(&err) + ctx = withCallerIDContext(ctx, request.CallerId) + + // Handle backward compatibility. + session := request.Session + if session == nil { + session = &vtgatepb.Session{Autocommit: true} + } + newSession, qrs, err := vtg.server.ExecuteMulti(ctx, nil, session, request.Sql) + return &vtgatepb.ExecuteMultiResponse{ + Results: sqltypes.ResultsToProto3(qrs), + Session: newSession, + Error: vterrors.ToVTRPC(err), + }, nil +} + +func (vtg *VTGate) StreamExecuteMulti(request *vtgatepb.StreamExecuteMultiRequest, stream vtgateservicepb.Vitess_StreamExecuteMultiServer) (err error) { + defer vtg.server.HandlePanic(&err) + ctx := withCallerIDContext(stream.Context(), request.CallerId) + + session := request.Session + if session == nil { + session = &vtgatepb.Session{Autocommit: true} + } + + session, vtgErr := vtg.server.StreamExecuteMulti(ctx, nil, session, request.Sql, func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error { + // Send is not safe to call concurrently, but vtgate + // guarantees that it's not. + return stream.Send(&vtgatepb.StreamExecuteMultiResponse{ + Result: sqltypes.QueryResponseToProto3(qr), + MoreResults: more, + NewResult: firstPacket, + }) + }) + + var errs []error + if vtgErr != nil { + errs = append(errs, vtgErr) + } + + if sendSessionInStreaming { + // even if there is an error, session could have been modified. + // So, this needs to be sent back to the client. Session is sent in the last stream response. + lastErr := stream.Send(&vtgatepb.StreamExecuteMultiResponse{ + Session: session, + }) + if lastErr != nil { + errs = append(errs, lastErr) + } + } + + return vterrors.ToGRPC(vterrors.Aggregate(errs)) +} + // ExecuteBatch is the RPC version of vtgateservice.VTGateService method func (vtg *VTGate) ExecuteBatch(ctx context.Context, request *vtgatepb.ExecuteBatchRequest) (response *vtgatepb.ExecuteBatchResponse, err error) { defer vtg.server.HandlePanic(&err) diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 512a278e0ff..7e285f157a4 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -267,6 +267,89 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq return callback(result) } +// ComQueryMulti is a newer version of ComQuery that supports running multiple queries in a single call. +func (vh *vtgateHandler) ComQueryMulti(c *mysql.Conn, sql string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) error { + session := vh.session(c) + if c.IsShuttingDown() && !session.InTransaction { + c.MarkForClose() + return sqlerror.NewSQLError(sqlerror.ERServerShutdown, sqlerror.SSNetError, "Server shutdown in progress") + } + + ctx, cancel := context.WithCancel(context.Background()) + c.UpdateCancelCtx(cancel) + + span, ctx, err := startSpan(ctx, sql, "vtgateHandler.ComQueryMulti") + if err != nil { + return vterrors.Wrap(err, "failed to extract span") + } + defer span.Finish() + + ctx = callinfo.MysqlCallInfo(ctx, c) + + // Fill in the ImmediateCallerID with the UserData returned by + // the AuthServer plugin for that user. If nothing was + // returned, use the User. This lets the plugin map a MySQL + // user used for authentication to a Vitess User used for + // Table ACLs and Vitess authentication in general. + im := c.UserData.Get() + ef := callerid.NewEffectiveCallerID( + c.User, /* principal: who */ + c.RemoteAddr().String(), /* component: running client process */ + "VTGate MySQL Connector" /* subcomponent: part of the client */) + ctx = callerid.NewContext(ctx, ef, im) + + if !session.InTransaction { + vh.busyConnections.Add(1) + } + defer func() { + if !session.InTransaction { + vh.busyConnections.Add(-1) + } + }() + + if session.Options.Workload == querypb.ExecuteOptions_OLAP { + if c.Capabilities&mysql.CapabilityClientMultiStatements != 0 { + session, err = vh.vtg.StreamExecuteMulti(ctx, vh, session, sql, callback) + } else { + firstPacket := true + session, err = vh.vtg.StreamExecute(ctx, vh, session, sql, make(map[string]*querypb.BindVariable), func(result *sqltypes.Result) error { + defer func() { + firstPacket = false + }() + return callback(sqltypes.QueryResponse{QueryResult: result}, false, firstPacket) + }) + } + if err != nil { + return sqlerror.NewSQLErrorFromError(err) + } + fillInTxStatusFlags(c, session) + return nil + } + var results []*sqltypes.Result + var result *sqltypes.Result + var queryResults []sqltypes.QueryResponse + if c.Capabilities&mysql.CapabilityClientMultiStatements != 0 { + session, results, err = vh.vtg.ExecuteMulti(ctx, vh, session, sql) + for _, res := range results { + queryResults = append(queryResults, sqltypes.QueryResponse{QueryResult: res}) + } + if err != nil { + queryResults = append(queryResults, sqltypes.QueryResponse{QueryError: sqlerror.NewSQLErrorFromError(err)}) + } + } else { + session, result, err = vh.vtg.Execute(ctx, vh, session, sql, make(map[string]*querypb.BindVariable), false) + queryResults = append(queryResults, sqltypes.QueryResponse{QueryResult: result, QueryError: sqlerror.NewSQLErrorFromError(err)}) + } + + fillInTxStatusFlags(c, session) + for idx, res := range queryResults { + if callbackErr := callback(res, idx < len(queryResults)-1, true); callbackErr != nil { + return callbackErr + } + } + return nil +} + func fillInTxStatusFlags(c *mysql.Conn, session *vtgatepb.Session) { if session.InTransaction { c.StatusFlags |= mysql.ServerStatusInTrans diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 73636f3463b..a311790d771 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -19,6 +19,7 @@ package vtgate import ( "context" "crypto/tls" + "errors" "fmt" "os" "path" @@ -31,11 +32,14 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/replication" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/trace" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/tlstest" "vitess.io/vitess/go/vt/vtenv" ) @@ -58,6 +62,25 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return callback(&sqltypes.Result{Fields: []*querypb.Field{}, Rows: [][]sqltypes.Value{}}) } +func (th *testHandler) ComQueryMulti(c *mysql.Conn, sql string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) error { + qries, err := th.Env().Parser().SplitStatementToPieces(sql) + if err != nil { + return err + } + for i, query := range qries { + firstPacket := true + err = th.ComQuery(c, query, func(result *sqltypes.Result) error { + err = callback(sqltypes.QueryResponse{QueryResult: result}, i < len(qries)-1, firstPacket) + firstPacket = false + return err + }) + if err != nil { + return err + } + } + return nil +} + func (th *testHandler) ComPrepare(*mysql.Conn, string) ([]*querypb.Field, uint16, error) { return nil, 0, nil } @@ -346,6 +369,457 @@ func TestKillMethods(t *testing.T) { require.True(t, mysqlConn.IsMarkedForClose()) } +func TestComQueryMulti(t *testing.T) { + testcases := []struct { + name string + sql string + olap bool + queryResponses []sqltypes.QueryResponse + more []bool + firstPacket []bool + errExpected bool + }{ + { + name: "Empty query", + sql: "", + queryResponses: []sqltypes.QueryResponse{ + {QueryResult: nil, QueryError: sqlerror.NewSQLErrorFromError(sqlparser.ErrEmpty)}, + }, + more: []bool{false}, + firstPacket: []bool{true}, + errExpected: false, + }, { + name: "Single query", + sql: "select 1", + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + }, + more: []bool{false}, + firstPacket: []bool{true}, + errExpected: false, + }, { + name: "Multiple queries - success", + sql: "select 1; select 2; select 3;", + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(2), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "3", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(3), + }, + }, + }, + QueryError: nil, + }, + }, + more: []bool{true, true, false}, + firstPacket: []bool{true, true, true}, + errExpected: false, + }, { + name: "Multiple queries - failure", + sql: "select 1; select 2; parsing error; select 3;", + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(2), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: nil, + QueryError: errors.New("syntax error at position 8 near 'parsing' (errno 1105) (sqlstate HY000)"), + }, + }, + more: []bool{true, true, false}, + firstPacket: []bool{true, true, true}, + errExpected: false, + }, { + name: "Empty query - olap", + sql: "", + olap: true, + queryResponses: []sqltypes.QueryResponse{}, + more: []bool{false}, + firstPacket: []bool{true}, + errExpected: true, + }, { + name: "Single query - olap", + sql: "select 1", + olap: true, + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + }, + more: []bool{false, false, false}, + firstPacket: []bool{true, false, false}, + errExpected: false, + }, { + name: "Multiple queries - olap - success", + sql: "select 1; select 2; select 3;", + olap: true, + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(2), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "3", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "3", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(3), + }, + }, + }, + QueryError: nil, + }, + }, + more: []bool{true, true, true, true, true, true, false, false, false}, + firstPacket: []bool{true, false, false, true, false, false, true, false, false}, + errExpected: false, + }, { + name: "Multiple queries - olap - failure", + sql: "select 1; select 2; parsing error; select 3;", + olap: true, + queryResponses: []sqltypes.QueryResponse{ + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "1", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(1), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: "2", + Type: sqltypes.Int64, + Flags: uint32(querypb.MySqlFlag_NUM_FLAG | querypb.MySqlFlag_NOT_NULL_FLAG), + Charset: collations.CollationBinaryID, + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: &sqltypes.Result{ + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(2), + }, + }, + }, + QueryError: nil, + }, + { + QueryResult: nil, + QueryError: errors.New("syntax error at position 8 near 'parsing' (errno 1105) (sqlstate HY000)"), + }, + }, + more: []bool{true, true, true, true, true, true, false}, + firstPacket: []bool{true, false, false, true, false, false, true}, + errExpected: false, + }, + } + + executor, _, _, _, _ := createExecutorEnv(t) + th := &testHandler{} + listener, err := mysql.NewListener("tcp", "127.0.0.1:", mysql.NewAuthServerNone(), th, 0, 0, false, false, 0, 0) + require.NoError(t, err) + defer listener.Close() + + // add a connection + mysqlConn := mysql.GetTestServerConn(listener) + mysqlConn.ConnectionID = 1 + mysqlConn.UserData = &mysql.StaticUserData{} + mysqlConn.Capabilities = mysqlConn.Capabilities | mysql.CapabilityClientMultiStatements + vh := newVtgateHandler(newVTGate(executor, nil, nil, nil, nil)) + vh.connections[1] = mysqlConn + for _, tt := range testcases { + t.Run(tt.name, func(t *testing.T) { + vh.session(mysqlConn).Options.Workload = querypb.ExecuteOptions_OLTP + if tt.olap { + vh.session(mysqlConn).Options.Workload = querypb.ExecuteOptions_OLAP + } + idx := 0 + err = vh.ComQueryMulti(mysqlConn, tt.sql, func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error { + assert.True(t, tt.queryResponses[idx].QueryResult.Equal(qr.QueryResult), "Result Got: %v", qr.QueryResult) + if tt.queryResponses[idx].QueryError != nil { + assert.Equal(t, tt.queryResponses[idx].QueryError.Error(), qr.QueryError.Error(), "Error Got: %v", qr.QueryError) + } else { + assert.Nil(t, qr.QueryError, "Error Got: %v", qr.QueryError) + } + assert.Equal(t, tt.more[idx], more, idx) + assert.Equal(t, tt.firstPacket[idx], firstPacket, idx) + idx++ + return nil + }) + assert.Equal(t, tt.errExpected, err != nil) + assert.Equal(t, len(tt.queryResponses), idx) + }) + } +} + func TestGracefulShutdown(t *testing.T) { executor, _, _, _, _ := createExecutorEnv(t) @@ -365,6 +839,10 @@ func TestGracefulShutdown(t *testing.T) { return nil }) assert.NoError(t, err) + err = vh.ComQueryMulti(mysqlConn, "select 1", func(res sqltypes.QueryResponse, more bool, firstPacket bool) error { + return nil + }) + assert.NoError(t, err) listener.Shutdown() @@ -372,6 +850,10 @@ func TestGracefulShutdown(t *testing.T) { return nil }) require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") + err = vh.ComQueryMulti(mysqlConn, "select 1", func(res sqltypes.QueryResponse, more bool, firstPacket bool) error { + return nil + }) + require.EqualError(t, err, "Server shutdown in progress (errno 1053) (sqlstate 08S01)") require.True(t, mysqlConn.IsMarkedForClose()) } diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 76774e0a8f6..6f522e55f06 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -32,6 +32,7 @@ import ( "github.com/spf13/viper" "vitess.io/vitess/go/acl" + "vitess.io/vitess/go/mysql/sqlerror" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/stats" "vitess.io/vitess/go/tb" @@ -583,6 +584,38 @@ func (vtg *VTGate) Execute( return session, nil, err } +// ExecuteMulti executes multiple non-streaming queries. +func (vtg *VTGate) ExecuteMulti( + ctx context.Context, + mysqlCtx vtgateservice.MySQLConnection, + session *vtgatepb.Session, + sqlString string, +) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) { + queries, err := vtg.executor.Environment().Parser().SplitStatementToPieces(sqlString) + if err != nil { + return session, nil, err + } + if len(queries) == 0 { + return session, nil, sqlparser.ErrEmpty + } + var qr *sqltypes.Result + var cancel context.CancelFunc + for _, query := range queries { + func() { + if mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(ctx, mysqlQueryTimeout) + defer cancel() + } + session, qr, err = vtg.Execute(ctx, mysqlCtx, session, query, make(map[string]*querypb.BindVariable), false) + }() + if err != nil { + return session, qrs, err + } + qrs = append(qrs, qr) + } + return session, qrs, nil +} + // ExecuteBatch executes a batch of queries. func (vtg *VTGate) ExecuteBatch(ctx context.Context, session *vtgatepb.Session, sqlList []string, bindVariablesList []map[string]*querypb.BindVariable) (*vtgatepb.Session, []sqltypes.QueryResponse, error) { // In this context, we don't care if we can't fully parse destination @@ -650,6 +683,47 @@ func (vtg *VTGate) StreamExecute(ctx context.Context, mysqlCtx vtgateservice.MyS return safeSession.Session, nil } +// StreamExecuteMulti executes a streaming query. +// Note we guarantee the callback will not be called concurrently by multiple go routines. +func (vtg *VTGate) StreamExecuteMulti(ctx context.Context, mysqlCtx vtgateservice.MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) { + queries, err := vtg.executor.Environment().Parser().SplitStatementToPieces(sqlString) + if err != nil { + return session, err + } + if len(queries) == 0 { + return session, sqlparser.ErrEmpty + } + var cancel context.CancelFunc + firstPacket := true + more := true + for idx, query := range queries { + firstPacket = true + more = idx < len(queries)-1 + func() { + if mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(ctx, mysqlQueryTimeout) + defer cancel() + } + session, err = vtg.StreamExecute(ctx, mysqlCtx, session, query, make(map[string]*querypb.BindVariable), func(result *sqltypes.Result) error { + defer func() { + firstPacket = false + }() + return callback(sqltypes.QueryResponse{QueryResult: result}, more, firstPacket) + }) + }() + if err != nil { + // We got an error before we sent a single packet. So it must be an error + // because of the query itself. We should return the error in the packet and stop + // processing any more queries. + if firstPacket { + return session, callback(sqltypes.QueryResponse{QueryError: sqlerror.NewSQLErrorFromError(err)}, false, true) + } + return session, err + } + } + return session, nil +} + // CloseSession closes the session, rolling back any implicit transactions. This has the // same effect as if a "rollback" statement was executed, but does not affect the query // statistics. diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index 7455bd6dd88..994fd176d91 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -131,6 +131,13 @@ func (sn *VTGateSession) ExecuteBatch(ctx context.Context, query []string, bindV return res, errs } +// ExecuteMulti performs a VTGate ExecuteMulti. +func (sn *VTGateSession) ExecuteMulti(ctx context.Context, query string) ([]*sqltypes.Result, error) { + session, res, err := sn.impl.ExecuteMulti(ctx, sn.session, query) + sn.session = session + return res, err +} + // StreamExecute executes a streaming query on vtgate. // It returns a ResultStream and an error. First check the // error. Then you can pull values from the ResultStream until io.EOF, @@ -144,6 +151,19 @@ func (sn *VTGateSession) StreamExecute(ctx context.Context, query string, bindVa }) } +// StreamExecuteMulti executes a set of streaming queries on vtgate. +// It returns a MultiResultStream and an error. First check the +// error. Then you can pull values from the MultiResultStream until io.EOF, +// or another error. The boolean field tells you when a new result starts. +func (sn *VTGateSession) StreamExecuteMulti(ctx context.Context, query string) (sqltypes.MultiResultStream, error) { + // passing in the function that will update the session when received on the stream. + return sn.impl.StreamExecuteMulti(ctx, sn.session, query, func(response *vtgatepb.StreamExecuteMultiResponse) { + if response.Session != nil { + sn.session = response.Session + } + }) +} + // Prepare performs a VTGate Prepare. func (sn *VTGateSession) Prepare(ctx context.Context, query string) ([]*querypb.Field, uint16, error) { session, fields, paramsCount, err := sn.impl.Prepare(ctx, sn.session, query) @@ -167,6 +187,12 @@ type Impl interface { // StreamExecute executes a streaming query on vtgate. StreamExecute(ctx context.Context, session *vtgatepb.Session, query string, bindVars map[string]*querypb.BindVariable, processResponse func(*vtgatepb.StreamExecuteResponse)) (sqltypes.ResultStream, error) + // ExecuteMulti executes multiple non-streaming queries. + ExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string) (*vtgatepb.Session, []*sqltypes.Result, error) + + // StreamExecuteMulti executes multiple streaming queries. + StreamExecuteMulti(ctx context.Context, session *vtgatepb.Session, sqlString string, processResponse func(response *vtgatepb.StreamExecuteMultiResponse)) (sqltypes.MultiResultStream, error) + // Prepare returns the fields information for the query as part of supporting prepare statements. Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) diff --git a/go/vt/vtgate/vtgateservice/interface.go b/go/vt/vtgate/vtgateservice/interface.go index e97020651d5..5e9414f8819 100644 --- a/go/vt/vtgate/vtgateservice/interface.go +++ b/go/vt/vtgate/vtgateservice/interface.go @@ -37,6 +37,12 @@ type VTGateService interface { // Prepare statement support Prepare(ctx context.Context, session *vtgatepb.Session, sql string) (*vtgatepb.Session, []*querypb.Field, uint16, error) + // ExecuteMulti executes multiple non-streaming queries. + ExecuteMulti(ctx context.Context, mysqlCtx MySQLConnection, session *vtgatepb.Session, sqlString string) (newSession *vtgatepb.Session, qrs []*sqltypes.Result, err error) + + // StreamExecuteMulti executes multiple streaming queries. + StreamExecuteMulti(ctx context.Context, mysqlCtx MySQLConnection, session *vtgatepb.Session, sqlString string, callback func(qr sqltypes.QueryResponse, more bool, firstPacket bool) error) (*vtgatepb.Session, error) + // CloseSession closes the session, rolling back any implicit transactions. // This has the same effect as if a "rollback" statement was executed, // but does not affect the query statistics. diff --git a/proto/vtgate.proto b/proto/vtgate.proto index f94724f2c96..314c3ecda52 100644 --- a/proto/vtgate.proto +++ b/proto/vtgate.proto @@ -177,6 +177,34 @@ message ReadAfterWrite { bool session_track_gtids = 3; } +// ExecuteMultiRequest is the payload to ExecuteMulti. +message ExecuteMultiRequest { + // caller_id identifies the caller. This is the effective caller ID, + // set by the application to further identify the caller. + vtrpc.CallerID caller_id = 1; + + // sql contains the set of queries to execute. + string sql = 2; + + // session carries the session state. + Session session = 3; +} + +// ExecuteMultiResponse is the returned value from ExecuteMulti. +message ExecuteMultiResponse { + // error contains an application level error if necessary. Note the + // session may have changed, even when an error is returned (for + // instance if a database integrity error happened). + vtrpc.RPCError error = 1; + + // session is the updated session information. + Session session = 2; + + // results contain the query results. There can be some results even if the + // error is set. + repeated query.QueryResult results = 3; +} + // ExecuteRequest is the payload to Execute. message ExecuteRequest { // caller_id identifies the caller. This is the effective caller ID, @@ -282,6 +310,34 @@ message StreamExecuteResponse { Session session = 2; } +// StreamExecuteMultiRequest is the payload to StreamExecuteMulti. +message StreamExecuteMultiRequest { + // caller_id identifies the caller. This is the effective caller ID, + // set by the application to further identify the caller. + vtrpc.CallerID caller_id = 1; + + // sql contains the set of queries to execute. + string sql = 2; + + // session carries the session state. + Session session = 3; +} + +// StreamExecuteMultiResponse is the returned value from StreamExecuteMulti. +message StreamExecuteMultiResponse { + // result contains the result set or an error if one occurred while executing the query. + query.ResultWithError result = 1; + + // more_results is set to true if there are more results to follow after this one has concluded. + bool more_results = 2; + + // new_result signifies a new result has started with this packet. + bool new_result = 3; + + // session is the updated session information. + Session session = 4; +} + // ResolveTransactionRequest is the payload to ResolveTransaction. message ResolveTransactionRequest { // caller_id identifies the caller. This is the effective caller ID, diff --git a/proto/vtgateservice.proto b/proto/vtgateservice.proto index fe6170b3ecc..557ff5e9223 100644 --- a/proto/vtgateservice.proto +++ b/proto/vtgateservice.proto @@ -35,6 +35,9 @@ service Vitess { // API group: v3 rpc Execute(vtgate.ExecuteRequest) returns (vtgate.ExecuteResponse) {}; + // ExecuteMulti executes multiple queries on the right shards. + rpc ExecuteMulti(vtgate.ExecuteMultiRequest) returns (vtgate.ExecuteMultiResponse) {}; + // ExecuteBatch tries to route the list of queries on the right shards. // It depends on the query and bind variables to provide enough // information in conjunction with the vindexes to route the query. @@ -48,6 +51,9 @@ service Vitess { // API group: v3 rpc StreamExecute(vtgate.StreamExecuteRequest) returns (stream vtgate.StreamExecuteResponse) {}; + // StreamExecuteMulti executes multiple streaming queries. + rpc StreamExecuteMulti(vtgate.StreamExecuteMultiRequest) returns (stream vtgate.StreamExecuteMultiResponse) {}; + // VStream streams binlog events from the requested sources. rpc VStream(vtgate.VStreamRequest) returns (stream vtgate.VStreamResponse) {}; diff --git a/test/config.json b/test/config.json index 01f9c9be2a4..3326ec34ba5 100644 --- a/test/config.json +++ b/test/config.json @@ -561,6 +561,15 @@ "RetryMax": 1, "Tags": ["upgrade_downgrade_query_serving_queries"] }, + "vtgate_queries_multi_query": { + "File": "unused.go", + "Args": ["vitess.io/vitess/go/test/endtoend/vtgate/queries/multi_query"], + "Command": [], + "Manual": false, + "Shard": "vtgate_queries", + "RetryMax": 1, + "Tags": ["upgrade_downgrade_query_serving_queries"] + }, "vtgate_queries_timeout": { "File": "unused.go", "Args": ["vitess.io/vitess/go/test/endtoend/vtgate/queries/timeout"],