Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions go/vt/vtexplain/vtexplain_vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar
vte.explainTopo.TopoServer = memorytopo.NewServer(ctx, vtexplainCell)
vte.healthCheck = discovery.NewFakeHealthCheck(nil)

resolver := vte.newFakeResolver(opts, vte.explainTopo, vtexplainCell)
resolver := vte.newFakeResolver(ctx, opts, vte.explainTopo, vtexplainCell)

err := vte.buildTopology(ctx, opts, vSchemaStr, ksShardMapStr, opts.NumShards)
if err != nil {
Expand All @@ -80,10 +80,9 @@ func (vte *VTExplain) initVtgateExecutor(ctx context.Context, vSchemaStr, ksShar
return nil
}

func (vte *VTExplain) newFakeResolver(opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver {
ctx := context.Background()
func (vte *VTExplain) newFakeResolver(ctx context.Context, opts *Options, serv srvtopo.Server, cell string) *vtgate.Resolver {
gw := vtgate.NewTabletGateway(ctx, vte.healthCheck, serv, cell)
_ = gw.WaitForTablets([]topodatapb.TabletType{topodatapb.TabletType_REPLICA})
_ = gw.WaitForTablets(ctx, []topodatapb.TabletType{topodatapb.TabletType_REPLICA})

txMode := vtgatepb.TransactionMode_MULTI
if opts.ExecutionMode == ModeTwoPC {
Expand Down
20 changes: 9 additions & 11 deletions go/vt/vtgate/legacy_scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ func TestLegacyExecuteFailOnAutocommit(t *testing.T) {
}

func TestScatterConnExecuteMulti(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnExecuteMultiShard", func(sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
ctx := utils.LeakCheckContext(t)
testScatterConnGeneric(t, "TestScatterConnExecuteMultiShard", func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa")
rss, err := res.ResolveDestination(ctx, "TestScatterConnExecuteMultiShard", topodatapb.TabletType_REPLICA, key.DestinationShards(shards))
if err != nil {
Expand All @@ -130,8 +129,7 @@ func TestScatterConnExecuteMulti(t *testing.T) {
}

func TestScatterConnStreamExecuteMulti(t *testing.T) {
testScatterConnGeneric(t, "TestScatterConnStreamExecuteMulti", func(sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
ctx := utils.LeakCheckContext(t)
testScatterConnGeneric(t, "TestScatterConnStreamExecuteMulti", func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error) {
res := srvtopo.NewResolver(newSandboxForCells(ctx, []string{"aa"}), sc.gateway, "aa")
rss, err := res.ResolveDestination(ctx, "TestScatterConnStreamExecuteMulti", topodatapb.TabletType_REPLICA, key.DestinationShards(shards))
if err != nil {
Expand All @@ -158,15 +156,15 @@ func verifyScatterConnError(t *testing.T, err error, wantErr string, wantCode vt
assert.Equal(t, wantCode, vterrors.Code(err))
}

func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, shards []string) (*sqltypes.Result, error)) {
func testScatterConnGeneric(t *testing.T, name string, f func(ctx context.Context, sc *ScatterConn, shards []string) (*sqltypes.Result, error)) {
ctx := utils.LeakCheckContext(t)

hc := discovery.NewFakeHealthCheck(nil)

// no shard
s := createSandbox(name)
sc := newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
qr, err := f(sc, nil)
qr, err := f(ctx, sc, nil)
require.NoError(t, err)
if qr.RowsAffected != 0 {
t.Errorf("want 0, got %v", qr.RowsAffected)
Expand All @@ -177,7 +175,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc := hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
_, err = f(sc, []string{"0"})
_, err = f(ctx, sc, []string{"0"})
want := fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error", name)
// Verify server error string.
if err == nil || err.Error() != want {
Expand All @@ -196,7 +194,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sbc1 := hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
_, err = f(sc, []string{"0", "1"})
_, err = f(ctx, sc, []string{"0", "1"})
// Verify server errors are consolidated.
want = fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error\ntarget: %v.1.replica: INVALID_ARGUMENT error", name, name)
verifyScatterConnError(t, err, want, vtrpcpb.Code_INVALID_ARGUMENT)
Expand All @@ -216,7 +214,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sbc1 = hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
sbc1.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 1
_, err = f(sc, []string{"0", "1"})
_, err = f(ctx, sc, []string{"0", "1"})
// Verify server errors are consolidated.
want = fmt.Sprintf("target: %v.0.replica: INVALID_ARGUMENT error\ntarget: %v.1.replica: RESOURCE_EXHAUSTED error", name, name)
// We should only surface the higher priority error code
Expand All @@ -234,7 +232,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
hc.Reset()
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc = hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
_, _ = f(sc, []string{"0", "0"})
_, _ = f(ctx, sc, []string{"0", "0"})
// Ensure that we executed only once.
if execCount := sbc.ExecCount.Load(); execCount != 1 {
t.Errorf("want 1, got %v", execCount)
Expand All @@ -246,7 +244,7 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
sc = newTestScatterConn(ctx, hc, newSandboxForCells(ctx, []string{"aa"}), "aa")
sbc0 = hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc1 = hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_REPLICA, true, 1, nil)
qr, err = f(sc, []string{"0", "1"})
qr, err = f(ctx, sc, []string{"0", "1"})
if err != nil {
t.Fatalf("want nil, got %v", err)
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/tabletgateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ func (gw *TabletGateway) RegisterStats() {
}

// WaitForTablets is part of the Gateway interface.
func (gw *TabletGateway) WaitForTablets(tabletTypesToWait []topodatapb.TabletType) (err error) {
func (gw *TabletGateway) WaitForTablets(ctx context.Context, tabletTypesToWait []topodatapb.TabletType) (err error) {
log.Infof("Gateway waiting for serving tablets of types %v ...", tabletTypesToWait)
ctx, cancel := context.WithTimeout(context.Background(), initialTabletTimeout)
ctx, cancel := context.WithTimeout(ctx, initialTabletTimeout)
defer cancel()

defer func() {
Expand Down
61 changes: 32 additions & 29 deletions go/vt/vtgate/tabletgateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,49 +37,55 @@ import (
)

func TestTabletGatewayExecute(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(context.Background(), target, "query", nil, 0, 0, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(ctx, target, "query", nil, 0, 0, nil)
return err
})
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(context.Background(), target, "query", nil, 1, 0, nil)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil)
return err
})
}

func TestTabletGatewayExecuteStream(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
err := tg.StreamExecute(context.Background(), target, "query", nil, 0, 0, nil, func(qr *sqltypes.Result) error {
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
err := tg.StreamExecute(ctx, target, "query", nil, 0, 0, nil, func(qr *sqltypes.Result) error {
return nil
})
return err
})
}

func TestTabletGatewayBegin(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Begin(context.Background(), target, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Begin(ctx, target, nil)
return err
})
}

func TestTabletGatewayCommit(t *testing.T) {
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Commit(context.Background(), target, 1)
ctx := utils.LeakCheckContext(t)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Commit(ctx, target, 1)
return err
})
}

func TestTabletGatewayRollback(t *testing.T) {
testTabletGatewayTransact(t, func(tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Rollback(context.Background(), target, 1)
ctx := utils.LeakCheckContext(t)
testTabletGatewayTransact(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, err := tg.Rollback(ctx, target, 1)
return err
})
}

func TestTabletGatewayBeginExecute(t *testing.T) {
testTabletGatewayGeneric(t, func(tg *TabletGateway, target *querypb.Target) error {
_, _, err := tg.BeginExecute(context.Background(), target, nil, "query", nil, 0, nil)
ctx := utils.LeakCheckContext(t)
testTabletGatewayGeneric(t, ctx, func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error {
_, _, err := tg.BeginExecute(ctx, target, nil, "query", nil, 0, nil)
return err
})
}
Expand Down Expand Up @@ -167,14 +173,12 @@ func TestTabletGatewayReplicaTransactionError(t *testing.T) {
defer tg.Close(ctx)

_ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
_, err := tg.Execute(context.Background(), target, "query", nil, 1, 0, nil)
_, err := tg.Execute(ctx, target, "query", nil, 1, 0, nil)
verifyContainsError(t, err, "query service can only be used for non-transactional queries on replicas", vtrpcpb.Code_INTERNAL)
}

func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *querypb.Target) error) {
func testTabletGatewayGeneric(t *testing.T, ctx context.Context, f func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error) {
t.Helper()
ctx := utils.LeakCheckContext(t)

keyspace := "ks"
shard := "0"
tabletType := topodatapb.TabletType_REPLICA
Expand All @@ -192,19 +196,19 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu

// no tablet
want := []string{"target: ks.0.replica", `no healthy tablet available for 'keyspace:"ks" shard:"0" tablet_type:REPLICA`}
err := f(tg, target)
err := f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// tablet with error
hc.Reset()
hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, false, 10, fmt.Errorf("no connection"))
err = f(tg, target)
err = f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// tablet without connection
hc.Reset()
_ = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, false, 10, nil).Tablet()
err = f(tg, target)
err = f(ctx, tg, target)
verifyShardErrors(t, err, want, vtrpcpb.Code_UNAVAILABLE)

// retry error
Expand All @@ -214,7 +218,7 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1

err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.replica", vtrpcpb.Code_FAILED_PRECONDITION)

// fatal error
Expand All @@ -223,26 +227,25 @@ func testTabletGatewayGeneric(t *testing.T, f func(tg *TabletGateway, target *qu
sc2 = hc.AddTestTablet("cell", host, port+1, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.replica", vtrpcpb.Code_FAILED_PRECONDITION)

// server error - no retry
hc.Reset()
sc1 = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
err = f(tg, target)
err = f(ctx, tg, target)
assert.Equal(t, vtrpcpb.Code_INVALID_ARGUMENT, vterrors.Code(err))

// no failure
hc.Reset()
hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
err = f(tg, target)
err = f(ctx, tg, target)
assert.NoError(t, err)
}

func testTabletGatewayTransact(t *testing.T, f func(tg *TabletGateway, target *querypb.Target) error) {
func testTabletGatewayTransact(t *testing.T, ctx context.Context, f func(ctx context.Context, tg *TabletGateway, target *querypb.Target) error) {
t.Helper()
ctx := utils.LeakCheckContext(t)

keyspace := "ks"
shard := "0"
Expand All @@ -267,14 +270,14 @@ func testTabletGatewayTransact(t *testing.T, f func(tg *TabletGateway, target *q
sc1.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1
sc2.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 1

err := f(tg, target)
err := f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.primary", vtrpcpb.Code_FAILED_PRECONDITION)

// server error - no retry
hc.Reset()
sc1 = hc.AddTestTablet("cell", host, port, keyspace, shard, tabletType, true, 10, nil)
sc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
err = f(tg, target)
err = f(ctx, tg, target)
verifyContainsError(t, err, "target: ks.0.primary", vtrpcpb.Code_INVALID_ARGUMENT)
}

Expand Down
Loading