Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions go/test/endtoend/vtgate/setstatement/sysvar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,77 @@ func TestStartTxAndSetSystemVariableAndThenSuccessfulCommit(t *testing.T) {
assertMatches(t, conn, "select @@sql_safe_updates", "[[INT64(1)]]")
}

func TestSetSystemVarAutocommitWithConnError(t *testing.T) {
vtParams := mysql.ConnParams{
Host: "localhost",
Port: clusterInstance.VtgateMySQLPort,
}

conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
defer conn.Close()

checkedExec(t, conn, "delete from test")
checkedExec(t, conn, "insert into test (id, val1) values (1, null), (4, null)")

checkedExec(t, conn, "set sql_safe_updates = 1") // this should force us into a reserved connection
assertMatches(t, conn, "select id from test order by id", "[[INT64(1)] [INT64(4)]]")
qr := checkedExec(t, conn, "select connection_id() from test where id = 1")

// kill the mysql connection shard which has transaction open.
vttablet1 := clusterInstance.Keyspaces[0].Shards[0].MasterTablet() // -80
_, err = vttablet1.VttabletProcess.QueryTablet(fmt.Sprintf("kill %s", qr.Rows[0][0].ToString()), keyspaceName, false)
require.NoError(t, err)

// first query to 80- shard should pass
assertMatches(t, conn, "select id, val1 from test where id = 4", "[[INT64(4) NULL]]")

// first query to -80 shard will fail
_, err = exec(t, conn, "insert into test (id, val1) values (2, null)")
require.Error(t, err)

// subsequent queries on -80 will pass
assertMatches(t, conn, "select id from test where id = 2", "[]")
assertMatches(t, conn, "insert into test (id, val1) values (2, null)", "[]")
assertMatches(t, conn, "select id, @@sql_safe_updates from test where id = 2", "[[INT64(2) INT64(1)]]")
}

func TestSetSystemVarInTxWithConnError(t *testing.T) {
vtParams := mysql.ConnParams{
Host: "localhost",
Port: clusterInstance.VtgateMySQLPort,
}

conn, err := mysql.Connect(context.Background(), &vtParams)
require.NoError(t, err)
defer conn.Close()

checkedExec(t, conn, "delete from test")
checkedExec(t, conn, "insert into test (id, val1) values (1, null), (4, null)")

checkedExec(t, conn, "set sql_safe_updates = 1") // this should force us into a reserved connection
qr := checkedExec(t, conn, "select connection_id() from test where id = 4")
checkedExec(t, conn, "begin")
checkedExec(t, conn, "insert into test (id, val1) values (2, null)")

// kill the mysql connection shard which has transaction open.
vttablet1 := clusterInstance.Keyspaces[0].Shards[1].MasterTablet() // 80-
_, err = vttablet1.VttabletProcess.QueryTablet(fmt.Sprintf("kill %s", qr.Rows[0][0].ToString()), keyspaceName, false)
require.NoError(t, err)

// query to -80 shard should pass and remain in transaction.
assertMatches(t, conn, "select id, val1 from test where id = 2", "[[INT64(2) NULL]]")
checkedExec(t, conn, "rollback")
assertMatches(t, conn, "select id, val1 from test where id = 2", "[]")

// first query to 80- shard will fail
_, err = exec(t, conn, "select @@sql_safe_updates from test where id = 4")
require.Error(t, err)

// subsequent queries on 80- will pass
assertMatches(t, conn, "select id, @@sql_safe_updates from test where id = 4", "[[INT64(4) INT64(1)]]")
}

func assertMatches(t *testing.T, conn *mysql.Conn, query, expected string) {
t.Helper()
qr, err := exec(t, conn, query)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,11 @@ func TestLegaceHealthCheckFailsOnReservedConnections(t *testing.T) {
}

func executeOnShards(t *testing.T, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *SafeSession, destinations []key.Destination) {
t.Helper()
require.Empty(t, executeOnShardsReturnsErr(t, res, keyspace, sc, session, destinations))
}

func executeOnShardsReturnsErr(t *testing.T, res *srvtopo.Resolver, keyspace string, sc *ScatterConn, session *SafeSession, destinations []key.Destination) error {
t.Helper()
rss, _, err := res.ResolveDestinations(ctx, keyspace, topodatapb.TabletType_REPLICA, nil, destinations)
require.NoError(t, err)
Expand All @@ -372,7 +377,7 @@ func executeOnShards(t *testing.T, res *srvtopo.Resolver, keyspace string, sc *S
}

_, errs := sc.ExecuteMultiShard(ctx, rss, queries, session, false, false)
require.Empty(t, errs)
return vterrors.Aggregate(errs)
}

func TestMultiExecs(t *testing.T) {
Expand Down
48 changes: 48 additions & 0 deletions go/vt/vtgate/safe_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -360,3 +360,51 @@ func (session *SafeSession) SetPreQueries() []string {
}
return result
}

//ResetShard reset the shard session for the provided tablet alias.
func (session *SafeSession) ResetShard(tabletAlias *topodatapb.TabletAlias) error {
session.mu.Lock()
defer session.mu.Unlock()

// Always append, in order for rollback to succeed.
switch session.commitOrder {
case vtgatepb.CommitOrder_NORMAL:
newSessions, err := removeShard(tabletAlias, session.ShardSessions)
if err != nil {
return err
}
session.ShardSessions = newSessions
case vtgatepb.CommitOrder_PRE:
newSessions, err := removeShard(tabletAlias, session.PreSessions)
if err != nil {
return err
}
session.PreSessions = newSessions
case vtgatepb.CommitOrder_POST:
newSessions, err := removeShard(tabletAlias, session.PostSessions)
if err != nil {
return err
}
session.PostSessions = newSessions
default:
// Should be unreachable
return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "BUG: SafeSession.ResetShard: unexpected commitOrder")
}
return nil
}

func removeShard(tabletAlias *topodatapb.TabletAlias, sessions []*vtgatepb.Session_ShardSession) ([]*vtgatepb.Session_ShardSession, error) {
idx := -1
for i, session := range sessions {
if proto.Equal(session.TabletAlias, tabletAlias) {
if session.TransactionId != 0 {
return nil, vterrors.New(vtrpcpb.Code_INTERNAL, "BUG: SafeSession.ResetShard: in transaction")
}
idx = i
}
}
if idx == -1 {
return sessions, nil
}
return append(sessions[:idx], sessions[idx+1:]...), nil
}
32 changes: 19 additions & 13 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"sync"
"time"

"vitess.io/vitess/go/mysql"

topodatapb "vitess.io/vitess/go/vt/proto/topodata"

"vitess.io/vitess/go/vt/vttablet/queryservice"
Expand Down Expand Up @@ -173,6 +175,7 @@ func (stc *ScatterConn) ExecuteMultiShard(
err error
opts *querypb.ExecuteOptions
alias *topodatapb.TabletAlias
qs queryservice.QueryService
)
transactionID := info.transactionID
reservedID := info.reservedID
Expand All @@ -188,36 +191,30 @@ func (stc *ScatterConn) ExecuteMultiShard(
}
}

qs, err = getQueryService(rs, info)
if err != nil {
return nil, err
}

switch info.actionNeeded {
case nothing:
qs, err := getQueryService(rs, info)
if err != nil {
return nil, err
}
innerqr, err = qs.Execute(ctx, rs.Target, queries[i].Sql, queries[i].BindVariables, info.transactionID, info.reservedID, opts)
if err != nil {
checkAndResetShardSession(info, err, session)
return nil, err
}
case begin:
qs, err := getQueryService(rs, info)
if err != nil {
return nil, err
}
innerqr, transactionID, alias, err = qs.BeginExecute(ctx, rs.Target, session.Savepoints, queries[i].Sql, queries[i].BindVariables, info.reservedID, opts)
if err != nil {
return info.updateTransactionID(transactionID, alias), err
}
case reserve:
qs, err := getQueryService(rs, info)
if err != nil {
return nil, err
}
innerqr, reservedID, alias, err = qs.ReserveExecute(ctx, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, info.transactionID, opts)
if err != nil {
return info.updateReservedID(reservedID, alias), err
}
case reserveBegin:
innerqr, transactionID, reservedID, alias, err = rs.Gateway.ReserveBeginExecute(ctx, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, opts)
innerqr, transactionID, reservedID, alias, err = qs.ReserveBeginExecute(ctx, rs.Target, session.SetPreQueries(), queries[i].Sql, queries[i].BindVariables, opts)
if err != nil {
return info.updateTransactionAndReservedID(transactionID, reservedID, alias), err
}
Expand All @@ -242,6 +239,15 @@ func (stc *ScatterConn) ExecuteMultiShard(
return qr, allErrors.GetErrors()
}

func checkAndResetShardSession(info *shardActionInfo, err error, session *SafeSession) {
if info.reservedID != 0 && info.transactionID == 0 {
sqlErr := mysql.NewSQLErrorFromError(err).(*mysql.SQLError)
if sqlErr.Number() == mysql.CRServerGone || sqlErr.Number() == mysql.CRServerLost {
session.ResetShard(info.alias)
}
}
}

func getQueryService(rs *srvtopo.ResolvedShard, info *shardActionInfo) (queryservice.QueryService, error) {
_, usingLegacyGw := rs.Gateway.(*DiscoveryGateway)
if usingLegacyGw {
Expand Down
19 changes: 19 additions & 0 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/key"

"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -280,3 +281,21 @@ func TestReservedBeginTableDriven(t *testing.T) {
})
}
}

func TestReservedConnFail(t *testing.T) {
keyspace := "keyspace"
createSandbox(keyspace)
hc := discovery.NewFakeHealthCheck()
sc := newTestScatterConn(hc, new(sandboxTopo), "aa")
sbc0 := hc.AddTestTablet("aa", "0", 1, keyspace, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
_ = hc.AddTestTablet("aa", "1", 1, keyspace, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
res := srvtopo.NewResolver(&sandboxTopo{}, sc.gateway, "aa")

session := NewSafeSession(&vtgatepb.Session{InTransaction: false, InReservedConn: true})
destinations := []key.Destination{key.DestinationShard("0")}
executeOnShards(t, res, keyspace, sc, session, destinations)
assert.Equal(t, 1, len(session.ShardSessions))
sbc0.ShardErr = mysql.NewSQLError(mysql.CRServerGone, mysql.SSUnknownSQLState, "lost connection")
_ = executeOnShardsReturnsErr(t, res, keyspace, sc, session, destinations)
assert.Zero(t, len(session.ShardSessions))
}
20 changes: 10 additions & 10 deletions go/vt/vtgate/tx_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestTxConnCommitSuccess(t *testing.T) {
}

func TestTxConnReservedCommitSuccess(t *testing.T) {
sc, sbc0, sbc1, rss0, _, rss01 := newLegacyTestTxConnEnv(t, "TestTxConn")
sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConn")
sc.txConn.mode = vtgatepb.TransactionMode_MULTI

// Sequence the executes to ensure commit order
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestTxConnReservedCommitSuccess(t *testing.T) {
},
TransactionId: 1,
ReservedId: 1,
TabletAlias: sbc0.Tablet().Alias,
TabletAlias: sbc1.Tablet().Alias,
}},
}
utils.MustMatch(t, &wantSession, session.Session, "Session")
Expand All @@ -178,7 +178,7 @@ func TestTxConnReservedCommitSuccess(t *testing.T) {
TabletType: topodatapb.TabletType_MASTER,
},
ReservedId: 2,
TabletAlias: sbc0.Tablet().Alias,
TabletAlias: sbc1.Tablet().Alias,
}},
}
utils.MustMatch(t, &wantSession, session.Session, "Session")
Expand Down Expand Up @@ -574,7 +574,7 @@ func TestTxConnCommitOrderSuccess(t *testing.T) {
}

func TestTxConnReservedCommitOrderSuccess(t *testing.T) {
sc, sbc0, sbc1, rss0, rss1, _ := newLegacyTestTxConnEnv(t, "TestTxConn")
sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
sc.txConn.mode = vtgatepb.TransactionMode_MULTI

queries := []*querypb.BoundQuery{{
Expand Down Expand Up @@ -661,7 +661,7 @@ func TestTxConnReservedCommitOrderSuccess(t *testing.T) {
},
TransactionId: 1,
ReservedId: 1,
TabletAlias: sbc0.Tablet().Alias,
TabletAlias: sbc1.Tablet().Alias,
}},
}
utils.MustMatch(t, &wantSession, session.Session, "Session")
Expand Down Expand Up @@ -699,7 +699,7 @@ func TestTxConnReservedCommitOrderSuccess(t *testing.T) {
TabletType: topodatapb.TabletType_MASTER,
},
ReservedId: 2,
TabletAlias: sbc0.Tablet().Alias,
TabletAlias: sbc1.Tablet().Alias,
}},
}
utils.MustMatch(t, &wantSession, session.Session, "Session")
Expand Down Expand Up @@ -857,7 +857,7 @@ func TestTxConnRollback(t *testing.T) {
}

func TestTxConnReservedRollback(t *testing.T) {
sc, sbc0, sbc1, rss0, _, rss01 := newLegacyTestTxConnEnv(t, "TxConnReservedRollback")
sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TxConnReservedRollback")

session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
sc.ExecuteMultiShard(ctx, rss0, queries, session, false, false)
Expand Down Expand Up @@ -892,19 +892,19 @@ func TestTxConnReservedRollback(t *testing.T) {
}

func TestTxConnReservedRollbackFailure(t *testing.T) {
sc, sbc0, sbc1, rss0, _, rss01 := newLegacyTestTxConnEnv(t, "TxConnReservedRollback")
sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TxConnReservedRollback")

session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
sc.ExecuteMultiShard(ctx, rss0, queries, session, false, false)
sc.ExecuteMultiShard(ctx, rss01, twoQueries, session, false, false)

sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
require.Error(t,
assert.Error(t,
sc.txConn.Rollback(ctx, session))
wantSession := vtgatepb.Session{
InReservedConn: true,
Warnings: []*querypb.QueryWarning{{
Message: "rollback encountered an error and connection to all shard for this session is released: Code: INVALID_ARGUMENT\nINVALID_ARGUMENT error\n\ntarget: TxConnReservedRollback.1.master, used tablet: aa-0 (1)",
Message: "rollback encountered an error and connection to all shard for this session is released: Code: INVALID_ARGUMENT\nINVALID_ARGUMENT error\n",
}},
}
utils.MustMatch(t, &wantSession, session.Session, "Session")
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vttablet/sandboxconn/sandboxconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ type SandboxConn struct {

sExecMu sync.Mutex
execMu sync.Mutex

ShardErr error
}

var _ queryservice.QueryService = (*SandboxConn)(nil) // compile-time interface check
Expand All @@ -125,6 +127,9 @@ func (sbc *SandboxConn) getError() error {
sbc.MustFailCodes[code] = count - 1
return vterrors.New(code, fmt.Sprintf("%v error", code))
}
if sbc.ShardErr != nil {
return sbc.ShardErr
}
return nil
}

Expand Down