diff --git a/go/cmd/vtadmin/main.go b/go/cmd/vtadmin/main.go index e77448a090b..227a395f764 100644 --- a/go/cmd/vtadmin/main.go +++ b/go/cmd/vtadmin/main.go @@ -92,7 +92,7 @@ func startTracing(cmd *cobra.Command) { } func run(cmd *cobra.Command, args []string) { - bootSpan, _ := trace.NewSpan(context.Background(), "vtadmin.boot") + bootSpan, ctx := trace.NewSpan(context.Background(), "vtadmin.boot") defer bootSpan.Finish() configs := clusterFileConfig.Combine(defaultClusterConfig, clusterConfigs) @@ -120,7 +120,7 @@ func run(cmd *cobra.Command, args []string) { } for i, cfg := range configs { - cluster, err := cfg.Cluster() + cluster, err := cfg.Cluster(ctx) if err != nil { bootSpan.Finish() fatal(err) diff --git a/go/vt/vtadmin/api.go b/go/vt/vtadmin/api.go index 16ca6a4b474..e35e47e3921 100644 --- a/go/vt/vtadmin/api.go +++ b/go/vt/vtadmin/api.go @@ -233,7 +233,7 @@ func (api *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { if clusterCookie, err := r.Cookie("cluster"); err == nil { urlDecoded, err := url.QueryUnescape(clusterCookie.Value) if err == nil { - c, id, err := dynamic.ClusterFromString(urlDecoded) + c, id, err := dynamic.ClusterFromString(r.Context(), urlDecoded) if id != "" { if err != nil { log.Warningf("failed to extract valid cluster from cookie; attempting to use existing cluster with id=%s; error: %s", id, err) @@ -389,10 +389,6 @@ func (api *API) CreateKeyspace(ctx context.Context, req *vtadminpb.CreateKeyspac return nil, err } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - ks, err := c.CreateKeyspace(ctx, req.Options) if err != nil { return nil, err @@ -419,10 +415,6 @@ func (api *API) CreateShard(ctx context.Context, req *vtadminpb.CreateShardReque return nil, err } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - return c.CreateShard(ctx, req.Options) } @@ -442,10 +434,6 @@ func (api *API) DeleteKeyspace(ctx context.Context, req *vtadminpb.DeleteKeyspac return nil, err } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - return c.DeleteKeyspace(ctx, req.Options) } @@ -465,10 +453,6 @@ func (api *API) DeleteShards(ctx context.Context, req *vtadminpb.DeleteShardsReq return nil, err } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - return c.DeleteShards(ctx, req.Options) } @@ -958,10 +942,6 @@ func (api *API) DeleteTablet(ctx context.Context, req *vtadminpb.DeleteTabletReq cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.DeleteTablets(ctx, &vtctldatapb.DeleteTabletsRequest{ TabletAliases: []*topodatapb.TabletAlias{ tablet.Tablet.Alias, @@ -991,10 +971,6 @@ func (api *API) ReparentTablet(ctx context.Context, req *vtadminpb.ReparentTable cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - r, err := c.Vtctld.ReparentTablet(ctx, &vtctldatapb.ReparentTabletRequest{ Tablet: tablet.Tablet.Alias, }) @@ -1023,10 +999,6 @@ func (api *API) RunHealthCheck(ctx context.Context, req *vtadminpb.RunHealthChec cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.RunHealthCheck(ctx, &vtctldatapb.RunHealthCheckRequest{ TabletAlias: tablet.Tablet.Alias, }) @@ -1055,10 +1027,6 @@ func (api *API) PingTablet(ctx context.Context, req *vtadminpb.PingTabletRequest cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.PingTablet(ctx, &vtctldatapb.PingTabletRequest{ TabletAlias: tablet.Tablet.Alias, }) @@ -1087,10 +1055,6 @@ func (api *API) SetReadOnly(ctx context.Context, req *vtadminpb.SetReadOnlyReque cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.SetWritable(ctx, &vtctldatapb.SetWritableRequest{ TabletAlias: tablet.Tablet.Alias, Writable: false, @@ -1120,10 +1084,6 @@ func (api *API) SetReadWrite(ctx context.Context, req *vtadminpb.SetReadWriteReq cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.SetWritable(ctx, &vtctldatapb.SetWritableRequest{ TabletAlias: tablet.Tablet.Alias, Writable: true, @@ -1153,10 +1113,6 @@ func (api *API) StartReplication(ctx context.Context, req *vtadminpb.StartReplic cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.StartReplication(ctx, &vtctldatapb.StartReplicationRequest{ TabletAlias: tablet.Tablet.Alias, }) @@ -1185,10 +1141,6 @@ func (api *API) StopReplication(ctx context.Context, req *vtadminpb.StopReplicat cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.StopReplication(ctx, &vtctldatapb.StopReplicationRequest{ TabletAlias: tablet.Tablet.Alias, }) @@ -1263,10 +1215,6 @@ func (api *API) GetVSchema(ctx context.Context, req *vtadminpb.GetVSchemaRequest return nil, nil } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - return c.GetVSchema(ctx, req.Keyspace) } @@ -1307,11 +1255,6 @@ func (api *API) GetVSchemas(ctx context.Context, req *vtadminpb.GetVSchemasReque cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - rec.RecordError(fmt.Errorf("Vtctld.Dial(cluster = %s): %w", c.ID, err)) - return - } - getKeyspacesSpan, getKeyspacesCtx := trace.NewSpan(ctx, "Cluster.GetKeyspaces") cluster.AnnotateSpan(c, getKeyspacesSpan) @@ -1510,10 +1453,6 @@ func (api *API) RefreshState(ctx context.Context, req *vtadminpb.RefreshStateReq cluster.AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - _, err = c.Vtctld.RefreshState(ctx, &vtctldatapb.RefreshStateRequest{ TabletAlias: tablet.Tablet.Alias, }) @@ -1565,10 +1504,6 @@ func (api *API) ValidateSchemaKeyspace(ctx context.Context, req *vtadminpb.Valid return nil, nil } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - res, err := c.Vtctld.ValidateSchemaKeyspace(ctx, &vtctldatapb.ValidateSchemaKeyspaceRequest{ Keyspace: req.Keyspace, }) @@ -1594,10 +1529,6 @@ func (api *API) ValidateVersionKeyspace(ctx context.Context, req *vtadminpb.Vali return nil, nil } - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - res, err := c.Vtctld.ValidateVersionKeyspace(ctx, &vtctldatapb.ValidateVersionKeyspaceRequest{ Keyspace: req.Keyspace, }) @@ -1647,10 +1578,6 @@ func (api *API) VTExplain(ctx context.Context, req *vtadminpb.VTExplainRequest) span.Annotate("tablet_alias", topoproto.TabletAliasString(tablet.Tablet.Alias)) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, err - } - var ( wg sync.WaitGroup er concurrency.AllErrorRecorder @@ -1723,9 +1650,7 @@ func (api *API) VTExplain(ctx context.Context, req *vtadminpb.VTExplainRequest) go func(c *cluster.Cluster) { defer wg.Done() - shards, err := c.FindAllShardsInKeyspace(ctx, req.Keyspace, cluster.FindAllShardsInKeyspaceOptions{ - SkipDial: true, - }) + shards, err := c.FindAllShardsInKeyspace(ctx, req.Keyspace, cluster.FindAllShardsInKeyspaceOptions{}) if err != nil { er.RecordError(err) return diff --git a/go/vt/vtadmin/api_test.go b/go/vt/vtadmin/api_test.go index 5b489400390..bf4d99dde71 100644 --- a/go/vt/vtadmin/api_test.go +++ b/go/vt/vtadmin/api_test.go @@ -4856,7 +4856,7 @@ func TestServeHTTP(t *testing.T) { "discovery": "{\"vtctlds\": [{\"host\":{\"fqdn\": \"localhost:15000\", \"hostname\": \"localhost:15999\"}}], \"vtgates\": [{\"host\": {\"hostname\": \"localhost:15991\"}}]}", }, }, - }.Cluster() + }.Cluster(context.Background()) tests := []struct { name string diff --git a/go/vt/vtadmin/cluster/cluster.go b/go/vt/vtadmin/cluster/cluster.go index 0ea0f099196..4648ff0324b 100644 --- a/go/vt/vtadmin/cluster/cluster.go +++ b/go/vt/vtadmin/cluster/cluster.go @@ -80,7 +80,7 @@ type Cluster struct { } // New creates a new Cluster from a Config. -func New(cfg Config) (*Cluster, error) { +func New(ctx context.Context, cfg Config) (*Cluster, error) { cluster := &Cluster{ ID: cfg.ID, Name: cfg.Name, @@ -112,8 +112,15 @@ func New(cfg Config) (*Cluster, error) { return nil, fmt.Errorf("error creating vtctldclient proxy config: %w", err) } + for _, opt := range cfg.vtctldConfigOpts { + vtctldCfg = opt(vtctldCfg) + } + cluster.DB = vtsql.New(vtsqlCfg) - cluster.Vtctld = vtctldclient.New(vtctldCfg) + cluster.Vtctld, err = vtctldclient.New(ctx, vtctldCfg) + if err != nil { + return nil, fmt.Errorf("error creating vtctldclient: %w", err) + } if cfg.TabletFQDNTmplStr != "" { cluster.TabletFQDNTmpl, err = template.New(cluster.ID + "-tablet-fqdn").Parse(cfg.TabletFQDNTmplStr) @@ -370,9 +377,6 @@ func (c *Cluster) DeleteShards(ctx context.Context, req *vtctldatapb.DeleteShard // FindAllShardsInKeyspaceOptions modify the behavior of a cluster's // FindAllShardsInKeyspace method. type FindAllShardsInKeyspaceOptions struct { - // SkipDial indicates that the cluster can assume the vtctldclient has - // already dialed up a connection to a vtctld. - SkipDial bool // skipPool indicates that the caller has already made a successful call to // Acquire on the topoReadPool. It is not exported, because the cluster // pools are not exported, so it's not possible to manually Acquire from @@ -392,12 +396,6 @@ func (c *Cluster) FindAllShardsInKeyspace(ctx context.Context, keyspace string, AnnotateSpan(c, span) span.Annotate("keyspace", keyspace) - if !opts.SkipDial { - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("failed to Dial vtctld for cluster = %s for FindAllShardsInKeyspace: %w", c.ID, err) - } - } - if !opts.skipPool { if err := c.topoReadPool.Acquire(ctx); err != nil { return nil, fmt.Errorf("FindAllShardsInKeyspace(%s) failed to acquire topoReadPool: %w", keyspace, err) @@ -444,10 +442,6 @@ func (c *Cluster) FindWorkflows(ctx context.Context, keyspaces []string, opts Fi AnnotateSpan(c, span) span.Annotate("active_only", opts.ActiveOnly) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("FindWorkflows(cluster = %v, keyspaces = %v, opts = %v) dial failed: %w", c.ID, keyspaces, opts, err) - } - return c.findWorkflows(ctx, keyspaces, opts) } @@ -796,10 +790,6 @@ func (c *Cluster) GetKeyspace(ctx context.Context, name string) (*vtadminpb.Keys AnnotateSpan(c, span) span.Annotate("keyspace", name) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("Vtctld.Dial failed for cluster = %s: %w", c.ID, err) - } - if err := c.topoReadPool.Acquire(ctx); err != nil { return nil, fmt.Errorf("GetKeyspace(%s) failed to acquire topoReadPool: %w", name, err) } @@ -813,7 +803,6 @@ func (c *Cluster) GetKeyspace(ctx context.Context, name string) (*vtadminpb.Keys } shards, err := c.FindAllShardsInKeyspace(ctx, name, FindAllShardsInKeyspaceOptions{ - SkipDial: true, skipPool: true, // we already acquired before making the GetKeyspace call }) if err != nil { @@ -834,10 +823,6 @@ func (c *Cluster) GetKeyspaces(ctx context.Context) ([]*vtadminpb.Keyspace, erro AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("Vtctld.Dial(cluster=%s) failed: %w", c.ID, err) - } - if err := c.topoReadPool.Acquire(ctx); err != nil { return nil, fmt.Errorf("GetKeyspaces() failed to acquire topoReadPool: %w", err) } @@ -861,7 +846,7 @@ func (c *Cluster) GetKeyspaces(ctx context.Context) ([]*vtadminpb.Keyspace, erro go func(i int, ks *vtctldatapb.Keyspace) { defer wg.Done() - shards, err := c.FindAllShardsInKeyspace(ctx, ks.Name, FindAllShardsInKeyspaceOptions{SkipDial: true}) + shards, err := c.FindAllShardsInKeyspace(ctx, ks.Name, FindAllShardsInKeyspaceOptions{}) if err != nil { rec.RecordError(err) return @@ -993,41 +978,12 @@ func (c *Cluster) GetSchema(ctx context.Context, keyspace string, opts GetSchema annotateGetSchemaRequest(opts.BaseRequest, span) vtadminproto.AnnotateSpanWithGetSchemaTableSizeOptions(opts.TableSizeOptions, span) - var ( - wg sync.WaitGroup - rec concurrency.AllErrorRecorder - - tablets []*vtadminpb.Tablet - ) - - // First, dial vtctld and fetch tablets concurrently. - wg.Add(1) - go func() { - defer wg.Done() - - if err := c.Vtctld.Dial(ctx); err != nil { - rec.RecordError(fmt.Errorf("failed to Dial vtctld for cluster = %s for GetSchema: %w", c.ID, err)) - return - } - }() - - wg.Add(1) - go func() { - defer wg.Done() - // Fetch all tablets for the keyspace. - var err error - - tablets, err = c.FindTablets(ctx, func(tablet *vtadminpb.Tablet) bool { - return tablet.Tablet.Keyspace == keyspace - }, -1) - if err != nil { - rec.RecordError(fmt.Errorf("%w for keyspace %s", errors.ErrNoTablet, keyspace)) - } - }() - - wg.Wait() - if rec.HasErrors() { - return nil, rec.Error() + // Fetch all tablets for the keyspace. + tablets, err := c.FindTablets(ctx, func(tablet *vtadminpb.Tablet) bool { + return tablet.Tablet.Keyspace == keyspace + }, -1) + if err != nil { + return nil, fmt.Errorf("%w for keyspace %s", errors.ErrNoTablet, keyspace) } tabletsToQuery, err := c.getTabletsToQueryForSchemas(ctx, keyspace, tablets, opts) @@ -1097,11 +1053,6 @@ func (c *Cluster) GetSchemas(ctx context.Context, opts GetSchemaOptions) ([]*vta span, ctx := trace.NewSpan(ctx, "Cluster.GetKeyspaces") defer span.Finish() - if err := c.Vtctld.Dial(ctx); err != nil { - rec.RecordError(fmt.Errorf("Vtctld.Dial(cluster=%s) failed: %w", c.ID, err)) - return - } - if err := c.topoReadPool.Acquire(ctx); err != nil { rec.RecordError(fmt.Errorf("GetKeyspaces() failed to acquire topoReadPool: %w", err)) return @@ -1292,7 +1243,7 @@ func (c *Cluster) getSchemaFromTablets(ctx context.Context, keyspace string, tab func (c *Cluster) getTabletsToQueryForSchemas(ctx context.Context, keyspace string, tablets []*vtadminpb.Tablet, opts GetSchemaOptions) ([]*vtadminpb.Tablet, error) { if opts.TableSizeOptions.AggregateSizes { - shards, err := c.FindAllShardsInKeyspace(ctx, keyspace, FindAllShardsInKeyspaceOptions{SkipDial: true}) + shards, err := c.FindAllShardsInKeyspace(ctx, keyspace, FindAllShardsInKeyspaceOptions{}) if err != nil { return nil, err } @@ -1422,10 +1373,6 @@ func (c *Cluster) GetSrvVSchema(ctx context.Context, cell string) (*vtadminpb.Sr AnnotateSpan(c, span) span.Annotate("cell", cell) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("Vtctld.Dial(cluster=%s) failed: %w", c.ID, err) - } - if err := c.topoReadPool.Acquire(ctx); err != nil { return nil, fmt.Errorf("GetSrvVSchema(%s) failed to acquire topoReadPool: %w", cell, err) } @@ -1454,10 +1401,6 @@ func (c *Cluster) GetSrvVSchemas(ctx context.Context, cells []string) ([]*vtadmi AnnotateSpan(c, span) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("Vtctld.Dial(cluster=%s) failed: %w", c.ID, err) - } - if err := c.topoReadPool.Acquire(ctx); err != nil { return nil, fmt.Errorf("GetSrvVSchema(cluster = %s, cells = %v) failed to acquire topoReadPool: %w", c.ID, cells, err) } @@ -1548,10 +1491,6 @@ func (c *Cluster) GetWorkflow(ctx context.Context, keyspace string, name string, span.Annotate("keyspace", keyspace) span.Annotate("workflow_name", name) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("GetWorkflow(cluster = %v, keyspace = %v, workflow = %v, opts = %+v) dial failed: %w", c.ID, keyspace, name, opts, err) - } - workflows, err := c.findWorkflows(ctx, []string{keyspace}, FindWorkflowsOptions{ ActiveOnly: opts.ActiveOnly, Filter: func(workflow *vtadminpb.Workflow) bool { @@ -1597,10 +1536,6 @@ func (c *Cluster) GetWorkflows(ctx context.Context, keyspaces []string, opts Get AnnotateSpan(c, span) span.Annotate("active_only", opts.ActiveOnly) - if err := c.Vtctld.Dial(ctx); err != nil { - return nil, fmt.Errorf("GetWorkflows(cluster = %v, keyspaces = %v, opts = %v) dial failed: %w", c.ID, keyspaces, opts, err) - } - return c.findWorkflows(ctx, keyspaces, FindWorkflowsOptions{ ActiveOnly: opts.ActiveOnly, IgnoreKeyspaces: opts.IgnoreKeyspaces, diff --git a/go/vt/vtadmin/cluster/cluster_internal_test.go b/go/vt/vtadmin/cluster/cluster_internal_test.go index 799202c3dc4..81d6cd213ce 100644 --- a/go/vt/vtadmin/cluster/cluster_internal_test.go +++ b/go/vt/vtadmin/cluster/cluster_internal_test.go @@ -121,7 +121,6 @@ func Test_getShardSets(t *testing.T) { }, topoReadPool: pools.NewRPCPool(5, 0, nil), } - require.NoError(t, c.Vtctld.Dial(context.Background())) tests := []struct { name string diff --git a/go/vt/vtadmin/cluster/cluster_test.go b/go/vt/vtadmin/cluster/cluster_test.go index 38293064cc2..c61490f1609 100644 --- a/go/vt/vtadmin/cluster/cluster_test.go +++ b/go/vt/vtadmin/cluster/cluster_test.go @@ -165,8 +165,6 @@ func TestCreateKeyspace(t *testing.T) { t.Parallel() cluster := testutil.BuildCluster(t, tt.cfg) - err := cluster.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial test vtctld") resp, err := cluster.CreateKeyspace(ctx, tt.req) if tt.shouldErr { @@ -263,10 +261,7 @@ func TestCreateShard(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - err := tt.tc.Cluster.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial in-process vtctld") - - _, err = tt.tc.Cluster.CreateShard(ctx, tt.req) + _, err := tt.tc.Cluster.CreateShard(ctx, tt.req) if tt.shouldErr { assert.Error(t, err) } else { @@ -356,8 +351,6 @@ func TestDeleteKeyspace(t *testing.T) { t.Parallel() cluster := testutil.BuildCluster(t, tt.cfg) - err := cluster.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial test vtctld") resp, err := cluster.DeleteKeyspace(ctx, tt.req) if tt.shouldErr { @@ -487,9 +480,6 @@ func TestDeleteShards(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - err := tt.tc.Cluster.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial in-process vtctld") - if tt.setup != nil { func() { t.Helper() @@ -497,7 +487,7 @@ func TestDeleteShards(t *testing.T) { }() } - _, err = tt.tc.Cluster.DeleteShards(ctx, tt.req) + _, err := tt.tc.Cluster.DeleteShards(ctx, tt.req) if tt.shouldErr { assert.Error(t, err) } else { @@ -1369,9 +1359,6 @@ func TestGetSchema(t *testing.T) { DBConfig: testutil.Dbcfg{}, }) - err := c.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial test vtctld") - schema, err := c.GetSchema(ctx, "testkeyspace", cluster.GetSchemaOptions{ BaseRequest: tt.req, }) @@ -1423,9 +1410,6 @@ func TestGetSchema(t *testing.T) { VtctldClient: vtctld, }) - err := c.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial test vtctld") - _, _ = c.GetSchema(ctx, "testkeyspace", cluster.GetSchemaOptions{ BaseRequest: req, }) @@ -2599,8 +2583,6 @@ func TestGetShardReplicationPositions(t *testing.T) { t.Parallel() c := testutil.BuildCluster(t, tt.cfg) - err := c.Vtctld.Dial(ctx) - require.NoError(t, err, "failed to dial test vtctld") resp, err := c.GetShardReplicationPositions(ctx, tt.req) if tt.shouldErr { @@ -2716,8 +2698,6 @@ func TestGetVSchema(t *testing.T) { t.Parallel() cluster := testutil.BuildCluster(t, tt.cfg) - err := cluster.Vtctld.Dial(ctx) - require.NoError(t, err, "could not dial test vtctld") vschema, err := cluster.GetVSchema(ctx, tt.keyspace) if tt.shouldErr { diff --git a/go/vt/vtadmin/cluster/config.go b/go/vt/vtadmin/cluster/config.go index 7a9f6aeb098..fafbbab1764 100644 --- a/go/vt/vtadmin/cluster/config.go +++ b/go/vt/vtadmin/cluster/config.go @@ -17,6 +17,7 @@ limitations under the License. package cluster import ( + "context" "encoding/json" stderrors "errors" "fmt" @@ -29,6 +30,7 @@ import ( "vitess.io/vitess/go/pools" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/vtadmin/errors" + "vitess.io/vitess/go/vt/vtadmin/vtctldclient" ) var ( @@ -61,11 +63,13 @@ type Config struct { TopoRWPoolConfig *RPCPoolConfig TopoReadPoolConfig *RPCPoolConfig WorkflowReadPoolConfig *RPCPoolConfig + + vtctldConfigOpts []vtctldclient.ConfigOption } // Cluster returns a new cluster instance from the given config. -func (cfg Config) Cluster() (*Cluster, error) { - return New(cfg) +func (cfg Config) Cluster(ctx context.Context) (*Cluster, error) { + return New(ctx, cfg) } // String is part of the flag.Value interface. @@ -355,3 +359,13 @@ func (cfg *RPCPoolConfig) parseFlag(name string, val string) (err error) { return nil } + +// WithVtctldTestConfigOptions returns a new Config with the given vtctldclient +// ConfigOptions appended to any existing ConfigOptions in the current Config. +// +// It should be used in tests only, and is exported to for use in the +// vtadmin/testutil package. +func (cfg Config) WithVtctldTestConfigOptions(opts ...vtctldclient.ConfigOption) Config { + cfg.vtctldConfigOpts = append(cfg.vtctldConfigOpts, opts...) + return cfg +} diff --git a/go/vt/vtadmin/cluster/dynamic/cluster.go b/go/vt/vtadmin/cluster/dynamic/cluster.go index c6d5ea2045a..a59706b2aa7 100644 --- a/go/vt/vtadmin/cluster/dynamic/cluster.go +++ b/go/vt/vtadmin/cluster/dynamic/cluster.go @@ -1,6 +1,7 @@ package dynamic import ( + "context" "encoding/base64" "strings" @@ -19,7 +20,7 @@ import ( // // Therefore, callers should handle the return values as follows: // -// c, id, err := dynamic.ClusterFromString(s) +// c, id, err := dynamic.ClusterFromString(ctx, s) // if id == "" { // // handle err, do not use `c`. // } @@ -29,12 +30,12 @@ import ( // // Use `c` (or existing cluster with ID == `id`) based on the dynamic cluster // api.WithCluster(c, id).DoAThing() // -func ClusterFromString(s string) (c *cluster.Cluster, id string, err error) { +func ClusterFromString(ctx context.Context, s string) (c *cluster.Cluster, id string, err error) { cfg, id, err := cluster.LoadConfig(base64.NewDecoder(base64.StdEncoding, strings.NewReader(s)), "json") if err != nil { return nil, id, err } - c, err = cfg.Cluster() + c, err = cfg.Cluster(ctx) return c, id, err } diff --git a/go/vt/vtadmin/cluster/dynamic/cluster_test.go b/go/vt/vtadmin/cluster/dynamic/cluster_test.go index 652dcd2fb80..54b2b46d44b 100644 --- a/go/vt/vtadmin/cluster/dynamic/cluster_test.go +++ b/go/vt/vtadmin/cluster/dynamic/cluster_test.go @@ -1,6 +1,7 @@ package dynamic import ( + "context" "encoding/base32" "encoding/base64" "testing" @@ -62,7 +63,7 @@ func TestClusterFromString(t *testing.T) { enc := tt.encoder([]byte(tt.s)) - c, id, err := ClusterFromString(enc) + c, id, err := ClusterFromString(context.Background(), enc) if tt.shouldErr { assert.Error(t, err) assert.Nil(t, c, "when err != nil, cluster must be nil") diff --git a/go/vt/vtadmin/cluster/dynamic/interceptors.go b/go/vt/vtadmin/cluster/dynamic/interceptors.go index 01fdc93f695..504d83a8736 100644 --- a/go/vt/vtadmin/cluster/dynamic/interceptors.go +++ b/go/vt/vtadmin/cluster/dynamic/interceptors.go @@ -86,7 +86,7 @@ func clusterFromIncomingContextMetadata(ctx context.Context) (*cluster.Cluster, return nil, "", false, nil } - c, id, err := ClusterFromString(clusterMetadata[0]) + c, id, err := ClusterFromString(ctx, clusterMetadata[0]) return c, id, true, err } diff --git a/go/vt/vtadmin/testutil/cluster.go b/go/vt/vtadmin/testutil/cluster.go index ee7674948b8..b435b388864 100644 --- a/go/vt/vtadmin/testutil/cluster.go +++ b/go/vt/vtadmin/testutil/cluster.go @@ -17,6 +17,7 @@ limitations under the License. package testutil import ( + "context" "database/sql" "fmt" "sync" @@ -98,11 +99,16 @@ func BuildCluster(t testing.TB, cfg TestClusterConfig) *cluster.Cluster { ID: cfg.Cluster.Id, Name: cfg.Cluster.Name, DiscoveryImpl: discoveryTestImplName, - } + }.WithVtctldTestConfigOptions(vtadminvtctldclient.WithDialFunc(func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { + return cfg.VtctldClient, nil + })) m.Lock() testdisco = disco - c, err := cluster.New(clusterConf) + c, err := cluster.New( + context.Background(), // consider updating this function to allow callers to provide a context. + clusterConf, + ) m.Unlock() require.NoError(t, err, "failed to create cluster from configs %+v %+v", clusterConf, cfg) @@ -123,11 +129,6 @@ func BuildCluster(t testing.TB, cfg TestClusterConfig) *cluster.Cluster { return sql.OpenDB(&fakevtsql.Connector{Tablets: tablets, ShouldErr: cfg.DBConfig.ShouldErr}), nil } - vtctld := c.Vtctld.(*vtadminvtctldclient.ClientProxy) - vtctld.DialFunc = func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { - return cfg.VtctldClient, nil - } - return c } diff --git a/go/vt/vtadmin/vtctldclient/config.go b/go/vt/vtadmin/vtctldclient/config.go index a975a696970..53b6fd83a5c 100644 --- a/go/vt/vtadmin/vtctldclient/config.go +++ b/go/vt/vtadmin/vtctldclient/config.go @@ -20,11 +20,13 @@ import ( "fmt" "github.com/spf13/pflag" + "google.golang.org/grpc" "vitess.io/vitess/go/vt/grpcclient" "vitess.io/vitess/go/vt/vtadmin/cluster/discovery" "vitess.io/vitess/go/vt/vtadmin/cluster/resolver" "vitess.io/vitess/go/vt/vtadmin/credentials" + "vitess.io/vitess/go/vt/vtctl/vtctldclient" vtadminpb "vitess.io/vitess/go/vt/proto/vtadmin" ) @@ -37,6 +39,24 @@ type Config struct { Cluster *vtadminpb.Cluster ResolverOptions *resolver.Options + + dialFunc func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) +} + +// ConfigOption is a function that mutates a Config. It should return the same +// Config structure, in a builder-pattern style. +type ConfigOption func(cfg *Config) *Config + +// WithDialFunc returns a ConfigOption that applies the given dial function to +// a Config. +// +// It is used to support dependency injection in tests, and needs to be exported +// for higher-level tests (via vtadmin/testutil). +func WithDialFunc(f func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error)) ConfigOption { + return func(cfg *Config) *Config { + cfg.dialFunc = f + return cfg + } } // Parse returns a new config with the given cluster and discovery, after diff --git a/go/vt/vtadmin/vtctldclient/proxy.go b/go/vt/vtadmin/vtctldclient/proxy.go index 85cafb130e4..bb5ad20ebb9 100644 --- a/go/vt/vtadmin/vtctldclient/proxy.go +++ b/go/vt/vtadmin/vtctldclient/proxy.go @@ -40,14 +40,10 @@ import ( // Proxy defines the connection interface of a proxied vtctldclient used by // VTAdmin clusters. type Proxy interface { - // Dial opens a gRPC connection to a vtctld in the cluster. If the Proxy - // already has a valid connection, this is a no-op. - Dial(ctx context.Context) error - // Close closes the underlying vtctldclient connection. This is a no-op if // the Proxy has no current, valid connection. It is safe to call repeatedly. - // Users may call Dial on a previously-closed Proxy to create a new - // connection, but that connection may not be to the same particular vtctld. + // + // Once closed, a proxy is not safe for reuse. Close() error vtctlservicepb.VtctldClient @@ -65,7 +61,7 @@ type ClientProxy struct { // DialFunc is called to open a new vtctdclient connection. In production, // this should always be grpcvtctldclient.NewWithDialOpts, but it is // exported for testing purposes. - DialFunc func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) + dialFunc func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) resolver grpcresolver.Builder m sync.Mutex @@ -80,43 +76,37 @@ type ClientProxy struct { // // It does not open a connection to a vtctld; users must call Dial before first // use. -func New(cfg *Config) *ClientProxy { - return &ClientProxy{ +func New(ctx context.Context, cfg *Config) (*ClientProxy, error) { + dialFunc := cfg.dialFunc + if dialFunc == nil { + dialFunc = grpcvtctldclient.NewWithDialOpts + } + + proxy := ClientProxy{ cfg: cfg, cluster: cfg.Cluster, creds: cfg.Credentials, - DialFunc: grpcvtctldclient.NewWithDialOpts, + dialFunc: dialFunc, resolver: cfg.ResolverOptions.NewBuilder(cfg.Cluster.Id), closed: true, } + + if err := proxy.dial(ctx); err != nil { + return nil, err + } + + return &proxy, nil } -// Dial is part of the Proxy interface. -func (vtctld *ClientProxy) Dial(ctx context.Context) error { +// dial invokes a grpc.Dial call with the discovery-backed resolver for vtctlds +// in the proxy's cluster. +// +// it is called once at ClientProxy instantiation (in New()). +func (vtctld *ClientProxy) dial(ctx context.Context) error { span, _ := trace.NewSpan(ctx, "VtctldClientProxy.Dial") defer span.Finish() vtadminproto.AnnotateClusterSpan(vtctld.cluster, span) - - vtctld.m.Lock() - defer vtctld.m.Unlock() - - if vtctld.VtctldClient != nil { - if !vtctld.closed { - span.Annotate("is_noop", true) - return nil - } - - span.Annotate("is_stale", true) - - if err := vtctld.closeLocked(); err != nil { - // Even if the client connection does not shut down cleanly, we don't want to block - // Dial from discovering a new vtctld. This makes VTAdmin's dialer more resilient, - // but, as a caveat, it _can_ potentially leak improperly-closed gRPC connections. - log.Errorf("error closing possibly-stale connection before re-dialing: %w", err) - } - } - span.Annotate("is_using_credentials", vtctld.creds != nil) opts := []grpc.DialOption{ @@ -132,13 +122,17 @@ func (vtctld *ClientProxy) Dial(ctx context.Context) error { opts = append(opts, grpc.WithResolvers(vtctld.resolver)) - // TODO: update DialFunc to take ctx as first arg. - client, err := vtctld.DialFunc(resolver.DialAddr(vtctld.resolver, "vtctld"), grpcclient.FailFast(false), opts...) + // TODO: update dialFunc to take ctx as first arg. + client, err := vtctld.dialFunc(resolver.DialAddr(vtctld.resolver, "vtctld"), grpcclient.FailFast(false), opts...) if err != nil { return err } log.Infof("Established gRPC connection to vtctld\n") + + vtctld.m.Lock() + defer vtctld.m.Unlock() + vtctld.dialedAt = time.Now() vtctld.VtctldClient = client vtctld.closed = false @@ -151,28 +145,19 @@ func (vtctld *ClientProxy) Close() error { vtctld.m.Lock() defer vtctld.m.Unlock() - return vtctld.closeLocked() -} - -func (vtctld *ClientProxy) closeLocked() error { if vtctld.VtctldClient == nil { vtctld.closed = true return nil } - err := vtctld.VtctldClient.Close() - + // TODO: (ajm188) Figure out if this comment is still accurate. // Mark the vtctld connection as "closed" from the proxy side even if // the client connection does not shut down cleanly. This makes VTAdmin's dialer more resilient, // but, as a caveat, it _can_ potentially leak improperly-closed gRPC connections. - vtctld.closed = true + defer func() { vtctld.closed = true }() - if err != nil { - return err - } - - return nil + return vtctld.VtctldClient.Close() } // Debug implements debug.Debuggable for ClientProxy. diff --git a/go/vt/vtadmin/vtctldclient/proxy_test.go b/go/vt/vtadmin/vtctldclient/proxy_test.go index 235ff8eafe0..520ab48ab1d 100644 --- a/go/vt/vtadmin/vtctldclient/proxy_test.go +++ b/go/vt/vtadmin/vtctldclient/proxy_test.go @@ -19,13 +19,13 @@ package vtctldclient import ( "context" "net" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" - grpcresolver "google.golang.org/grpc/resolver" "vitess.io/vitess/go/vt/vtadmin/cluster/discovery/fakediscovery" "vitess.io/vitess/go/vt/vtadmin/cluster/resolver" @@ -80,7 +80,7 @@ func TestDial(t *testing.T) { Hostname: listener.Addr().String(), }) - proxy := New(&Config{ + proxy, err := New(context.Background(), &Config{ Cluster: &vtadminpb.Cluster{ Id: "test", Name: "testcluster", @@ -90,42 +90,34 @@ func TestDial(t *testing.T) { DiscoveryTimeout: 50 * time.Millisecond, }, }) - defer proxy.Close() // prevents grpc-core from logging a bunch of "connection errors" after deferred listener.Close() above. + require.NoError(t, err) - err = proxy.Dial(context.Background()) - assert.NoError(t, err) + defer proxy.Close() // prevents grpc-core from logging a bunch of "connection errors" after deferred listener.Close() above. resp, err := proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{}) require.NoError(t, err) assert.Equal(t, listener.Addr().String(), resp.Keyspace.Name) } -// testResolverBuilder wraps a grpcresolver.Builder to return *testResolvers -// with a channel to detect calls to ResolveNow in tests. -type testResolverBuilder struct { - grpcresolver.Builder - fired chan struct{} +type testdisco struct { + *fakediscovery.Fake + notify chan struct{} + fired chan struct{} + m sync.Mutex } -func (b *testResolverBuilder) Build(target grpcresolver.Target, cc grpcresolver.ClientConn, opts grpcresolver.BuildOptions) (grpcresolver.Resolver, error) { - r, err := b.Builder.Build(target, cc, opts) - if err != nil { - return nil, err - } - - return &testResolver{r, b.fired}, nil -} +func (d *testdisco) DiscoverVtctldAddrs(ctx context.Context, tags []string) ([]string, error) { + d.m.Lock() + defer d.m.Unlock() -// testResolver wraps a grpcresolver.Resolver to signal when ResolveNow is -// called in tests. -type testResolver struct { - grpcresolver.Resolver - fired chan struct{} -} - -func (r *testResolver) ResolveNow(o grpcresolver.ResolveNowOptions) { - r.Resolver.ResolveNow(o) - r.fired <- struct{}{} + select { + case <-d.notify: + defer func() { + go func() { d.fired <- struct{}{} }() + }() + default: + } + return d.Fake.DiscoverVtctldAddrs(ctx, tags) } // TestRedial tests that vtadmin-api is able to recover from a lost connection to @@ -149,32 +141,33 @@ func TestRedial(t *testing.T) { go server2.Serve(listener2) defer server2.Stop() + reResolveFired := make(chan struct{}, 1) + // Register both vtctlds with VTAdmin - disco := fakediscovery.New() + disco := &testdisco{ + Fake: fakediscovery.New(), + notify: make(chan struct{}), + fired: reResolveFired, + } disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{ Hostname: listener1.Addr().String(), }, &vtadminpb.Vtctld{ Hostname: listener2.Addr().String(), }) - reResolveFired := make(chan struct{}) - proxy := New(&Config{ + proxy, err := New(context.Background(), &Config{ Cluster: &vtadminpb.Cluster{ Id: "test", Name: "testcluster", }, ResolverOptions: &resolver.Options{ - Discovery: disco, - DiscoveryTimeout: 50 * time.Millisecond, + Discovery: disco, + DiscoveryTimeout: 50 * time.Millisecond, + MinDiscoveryInterval: 0, + BackoffStrategy: "none", }, }) - - // wrap the resolver builder to test that re-resolve has fired as expected. - proxy.resolver = &testResolverBuilder{Builder: proxy.resolver, fired: reResolveFired} - - // Check for a successful connection to whichever vtctld we discover first. - err = proxy.Dial(context.Background()) - assert.NoError(t, err) + require.NoError(t, err) // vtadmin's fakediscovery package discovers vtctlds in random order. Rather // than force some cumbersome sequential logic, we can just do a switcheroo @@ -182,6 +175,7 @@ func TestRedial(t *testing.T) { var currentVtctld *grpc.Server var nextAddr string + // Check for a successful connection to whichever vtctld we discover first. resp, err := proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{}) require.NoError(t, err) @@ -198,29 +192,38 @@ func TestRedial(t *testing.T) { t.Fatalf("invalid proxy host %s", proxyHost) } - // Remove the shut down vtctld from VTAdmin's service discovery (clumsily). - // Otherwise, when redialing, we may redial the vtctld that we just shut down. + // Shut down the vtctld we're connected to, then await re-resolution. + + // 1. First, block calls to DiscoverVtctldAddrs so we don't race with the + // background resolver watcher. + disco.m.Lock() + + // 2. Force an ungraceful shutdown of the gRPC server to which we're + // connected. + currentVtctld.Stop() + + // 3. Remove the shut down vtctld from VTAdmin's service discovery + // (clumsily). Otherwise, when redialing, we may redial the vtctld that we + // just shut down. disco.Clear() disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{ Hostname: nextAddr, }) - // Force an ungraceful shutdown of the gRPC server to which we're connected - currentVtctld.Stop() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() + // 4. Notify our wrapped DiscoverVtctldAddrs function to start signaling on + // its `fired` channel when called. + close(disco.notify) + // 5. Unblock calls to DiscoverVtctldAddrs, and move on to our assertions. + disco.m.Unlock() + maxWait := time.Second select { case <-reResolveFired: - case <-ctx.Done(): - require.FailNowf(t, "forced shutdown of vtctld should trigger grpc re-resolution", ctx.Err().Error()) + case <-time.After(maxWait): + require.FailNowf(t, "forced shutdown of vtctld should trigger grpc re-resolution", "did not receive re-resolve signal within %s", maxWait) } - // Finally, check that we discover, dial + establish a new connection to the remaining vtctld. - err = proxy.Dial(context.Background()) - assert.NoError(t, err) - + // Finally, check that we discover + establish a new connection to the remaining vtctld. resp, err = proxy.GetKeyspace(context.Background(), &vtctldatapb.GetKeyspaceRequest{}) require.NoError(t, err) assert.Equal(t, nextAddr, resp.Keyspace.Name)