diff --git a/go/vt/vtadmin/cluster/cluster.go b/go/vt/vtadmin/cluster/cluster.go index 4648ff0324b..d3bb91fb6e8 100644 --- a/go/vt/vtadmin/cluster/cluster.go +++ b/go/vt/vtadmin/cluster/cluster.go @@ -105,6 +105,10 @@ func New(ctx context.Context, cfg Config) (*Cluster, error) { return nil, fmt.Errorf("error creating vtsql connection config: %w", err) } + for _, opt := range cfg.vtsqlConfigOpts { + vtsqlCfg = opt(vtsqlCfg) + } + vtctldargs := buildPFlagSlice(cfg.VtctldFlags) vtctldCfg, err := vtctldclient.Parse(protocluster, disco, vtctldargs) @@ -116,7 +120,11 @@ func New(ctx context.Context, cfg Config) (*Cluster, error) { vtctldCfg = opt(vtctldCfg) } - cluster.DB = vtsql.New(vtsqlCfg) + cluster.DB, err = vtsql.New(ctx, vtsqlCfg) + if err != nil { + return nil, fmt.Errorf("error creating vtsql proxy: %w", err) + } + cluster.Vtctld, err = vtctldclient.New(ctx, vtctldCfg) if err != nil { return nil, fmt.Errorf("error creating vtctldclient: %w", err) @@ -883,10 +891,6 @@ func (c *Cluster) GetTablets(ctx context.Context) ([]*vtadminpb.Tablet, error) { } func (c *Cluster) getTablets(ctx context.Context) ([]*vtadminpb.Tablet, error) { - if err := c.DB.Dial(ctx, ""); err != nil { - return nil, err - } - rows, err := c.DB.ShowTablets(ctx) if err != nil { return nil, err diff --git a/go/vt/vtadmin/cluster/cluster_test.go b/go/vt/vtadmin/cluster/cluster_test.go index c61490f1609..17994a8d1c3 100644 --- a/go/vt/vtadmin/cluster/cluster_test.go +++ b/go/vt/vtadmin/cluster/cluster_test.go @@ -18,7 +18,6 @@ package cluster_test import ( "context" - "database/sql" "errors" "fmt" "strings" @@ -33,14 +32,10 @@ import ( "vitess.io/vitess/go/protoutil" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/topo" - "vitess.io/vitess/go/vt/vitessdriver" "vitess.io/vitess/go/vt/vtadmin/cluster" - "vitess.io/vitess/go/vt/vtadmin/cluster/discovery/fakediscovery" - "vitess.io/vitess/go/vt/vtadmin/cluster/resolver" vtadminerrors "vitess.io/vitess/go/vt/vtadmin/errors" "vitess.io/vitess/go/vt/vtadmin/testutil" "vitess.io/vitess/go/vt/vtadmin/vtctldclient/fakevtctldclient" - "vitess.io/vitess/go/vt/vtadmin/vtsql" "vitess.io/vitess/go/vt/vtctl/vtctldclient" replicationdatapb "vitess.io/vitess/go/vt/proto/replicationdata" @@ -2596,32 +2591,6 @@ func TestGetShardReplicationPositions(t *testing.T) { } } -// This test only validates the error handling on dialing database connections. -// Other cases are covered by one or both of TestFindTablets and TestFindTablet. -func TestGetTablets(t *testing.T) { - t.Parallel() - - disco := fakediscovery.New() - disco.AddTaggedGates(nil, &vtadminpb.VTGate{Hostname: "gate"}) - - db := vtsql.New(&vtsql.Config{ - Cluster: &vtadminpb.Cluster{ - Id: "c1", - Name: "one", - }, - ResolverOptions: &resolver.Options{ - Discovery: disco, - }, - }) - db.DialFunc = func(cfg vitessdriver.Configuration) (*sql.DB, error) { - return nil, assert.AnError - } - - c := &cluster.Cluster{DB: db} - _, err := c.GetTablets(context.Background()) - assert.Error(t, err) -} - func TestGetVSchema(t *testing.T) { t.Parallel() diff --git a/go/vt/vtadmin/cluster/config.go b/go/vt/vtadmin/cluster/config.go index fafbbab1764..eacdc2a4834 100644 --- a/go/vt/vtadmin/cluster/config.go +++ b/go/vt/vtadmin/cluster/config.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/vtadmin/errors" "vitess.io/vitess/go/vt/vtadmin/vtctldclient" + "vitess.io/vitess/go/vt/vtadmin/vtsql" ) var ( @@ -65,6 +66,7 @@ type Config struct { WorkflowReadPoolConfig *RPCPoolConfig vtctldConfigOpts []vtctldclient.ConfigOption + vtsqlConfigOpts []vtsql.ConfigOption } // Cluster returns a new cluster instance from the given config. @@ -369,3 +371,13 @@ func (cfg Config) WithVtctldTestConfigOptions(opts ...vtctldclient.ConfigOption) cfg.vtctldConfigOpts = append(cfg.vtctldConfigOpts, opts...) return cfg } + +// WithVtSQLTestConfigOptions returns a new Config with the given vtsql +// ConfigOptions appended to any existing ConfigOptions in the current Config. +// +// It should be used in tests only, and is exported to for use in the +// vtadmin/testutil package. +func (cfg Config) WithVtSQLTestConfigOptions(opts ...vtsql.ConfigOption) Config { + cfg.vtsqlConfigOpts = append(cfg.vtsqlConfigOpts, opts...) + return cfg +} diff --git a/go/vt/vtadmin/testutil/cluster.go b/go/vt/vtadmin/testutil/cluster.go index b435b388864..38765774e7a 100644 --- a/go/vt/vtadmin/testutil/cluster.go +++ b/go/vt/vtadmin/testutil/cluster.go @@ -95,12 +95,25 @@ func BuildCluster(t testing.TB, cfg TestClusterConfig) *cluster.Cluster { disco.AddTaggedGates(nil, &vtadminpb.VTGate{Hostname: fmt.Sprintf("%s-%s-gate", cfg.Cluster.Name, cfg.Cluster.Id)}) disco.AddTaggedVtctlds(nil, &vtadminpb.Vtctld{Hostname: "doesn't matter"}) + tablets := make([]*vtadminpb.Tablet, len(cfg.Tablets)) + for i, t := range cfg.Tablets { + tablet := &vtadminpb.Tablet{ + Cluster: cfg.Cluster, + Tablet: t.Tablet, + State: t.State, + } + + tablets[i] = tablet + } + clusterConf := cluster.Config{ ID: cfg.Cluster.Id, Name: cfg.Cluster.Name, DiscoveryImpl: discoveryTestImplName, }.WithVtctldTestConfigOptions(vtadminvtctldclient.WithDialFunc(func(addr string, ff grpcclient.FailFast, opts ...grpc.DialOption) (vtctldclient.VtctldClient, error) { return cfg.VtctldClient, nil + })).WithVtSQLTestConfigOptions(vtsql.WithDialFunc(func(c vitessdriver.Configuration) (*sql.DB, error) { + return sql.OpenDB(&fakevtsql.Connector{Tablets: tablets, ShouldErr: cfg.DBConfig.ShouldErr}), nil })) m.Lock() @@ -113,22 +126,6 @@ func BuildCluster(t testing.TB, cfg TestClusterConfig) *cluster.Cluster { require.NoError(t, err, "failed to create cluster from configs %+v %+v", clusterConf, cfg) - tablets := make([]*vtadminpb.Tablet, len(cfg.Tablets)) - for i, t := range cfg.Tablets { - tablet := &vtadminpb.Tablet{ - Cluster: cfg.Cluster, - Tablet: t.Tablet, - State: t.State, - } - - tablets[i] = tablet - } - - db := c.DB.(*vtsql.VTGateProxy) - db.DialFunc = func(_ vitessdriver.Configuration) (*sql.DB, error) { - return sql.OpenDB(&fakevtsql.Connector{Tablets: tablets, ShouldErr: cfg.DBConfig.ShouldErr}), nil - } - return c } diff --git a/go/vt/vtadmin/vtsql/config.go b/go/vt/vtadmin/vtsql/config.go index ef324667bf0..bf3eee89fab 100644 --- a/go/vt/vtadmin/vtsql/config.go +++ b/go/vt/vtadmin/vtsql/config.go @@ -17,11 +17,13 @@ limitations under the License. package vtsql import ( + "database/sql" "fmt" "github.com/spf13/pflag" "vitess.io/vitess/go/vt/grpcclient" + "vitess.io/vitess/go/vt/vitessdriver" "vitess.io/vitess/go/vt/vtadmin/cluster/discovery" "vitess.io/vitess/go/vt/vtadmin/cluster/resolver" "vitess.io/vitess/go/vt/vtadmin/credentials" @@ -39,6 +41,24 @@ type Config struct { Cluster *vtadminpb.Cluster ResolverOptions *resolver.Options + + dialFunc func(c vitessdriver.Configuration) (*sql.DB, error) +} + +// ConfigOption is a function that mutates a Config. It should return the same +// Config structure, in a builder-pattern style. +type ConfigOption func(cfg *Config) *Config + +// WithDialFunc returns a ConfigOption that applies the given dial function to +// a Config. +// +// It is used to support dependency injection in tests, and needs to be exported +// for higher-level tests (for example, package vtadmin/cluster). +func WithDialFunc(f func(c vitessdriver.Configuration) (*sql.DB, error)) ConfigOption { + return func(cfg *Config) *Config { + cfg.dialFunc = f + return cfg + } } // Parse returns a new config with the given cluster ID and name, after diff --git a/go/vt/vtadmin/vtsql/vtsql.go b/go/vt/vtadmin/vtsql/vtsql.go index 48f5e5b9cb7..5aeef853c50 100644 --- a/go/vt/vtadmin/vtsql/vtsql.go +++ b/go/vt/vtadmin/vtsql/vtsql.go @@ -44,21 +44,15 @@ type DB interface { // ShowTablets executes `SHOW vitess_tablets` and returns the result. ShowTablets(ctx context.Context) (*sql.Rows, error) - // Dial opens a gRPC database connection to a vtgate in the cluster. If the - // DB already has a valid connection, this is a no-op. - // - // target is a Vitess query target, e.g. "", "", "@replica". - Dial(ctx context.Context, target string, opts ...grpc.DialOption) error - // Ping behaves like (*sql.DB).Ping. Ping() error // PingContext behaves like (*sql.DB).PingContext. PingContext(ctx context.Context) error - // Close closes the currently-held database connection. This is a no-op if + // Close closes the underlying database connection. This is a no-op if // the DB has no current valid connection. It is safe to call repeatedly. - // Users may call Dial on a previously-closed DB to create a new connection, - // but that connection may not be to the same particular vtgate. + // + // Once closed, a DB is not safe for reuse. Close() error } @@ -72,7 +66,7 @@ type VTGateProxy struct { // DialFunc is called to open a new database connection. In production this // should always be vitessdriver.OpenWithConfiguration, but it is exported // for testing purposes. - DialFunc func(cfg vitessdriver.Configuration) (*sql.DB, error) + dialFunc func(cfg vitessdriver.Configuration) (*sql.DB, error) resolver grpcresolver.Builder m sync.Mutex @@ -92,14 +86,25 @@ var ErrConnClosed = errors.New("use of closed connection") // // It does not open a connection to a vtgate; users must call Dial before first // use. -func New(cfg *Config) *VTGateProxy { - return &VTGateProxy{ +func New(ctx context.Context, cfg *Config) (*VTGateProxy, error) { + dialFunc := cfg.dialFunc + if dialFunc == nil { + dialFunc = vitessdriver.OpenWithConfiguration + } + + proxy := VTGateProxy{ cluster: cfg.Cluster, creds: cfg.Credentials, cfg: cfg, - DialFunc: vitessdriver.OpenWithConfiguration, + dialFunc: dialFunc, resolver: cfg.ResolverOptions.NewBuilder(cfg.Cluster.Id), } + + if err := proxy.dial(ctx, ""); err != nil { + return nil, err + } + + return &proxy, nil } // getQueryContext returns a new context with the correct effective and immediate @@ -123,23 +128,12 @@ func (vtgate *VTGateProxy) getQueryContext(ctx context.Context) context.Context // Dial is part of the DB interface. The proxy's DiscoveryTags can be set to // narrow the set of possible gates it will connect to. -func (vtgate *VTGateProxy) Dial(ctx context.Context, target string, opts ...grpc.DialOption) error { +func (vtgate *VTGateProxy) dial(ctx context.Context, target string, opts ...grpc.DialOption) error { span, _ := trace.NewSpan(ctx, "VTGateProxy.Dial") defer span.Finish() vtadminproto.AnnotateClusterSpan(vtgate.cluster, span) - - vtgate.m.Lock() - defer vtgate.m.Unlock() - - if vtgate.conn != nil { - log.Info("Have valid connection to vtgate, reusing it.") - span.Annotate("is_noop", true) - - return nil - } - - span.Annotate("is_noop", false) + span.Annotate("is_using_credentials", vtgate.creds != nil) conf := vitessdriver.Configuration{ Protocol: fmt.Sprintf("grpc_%s", vtgate.cluster.Id), @@ -154,11 +148,16 @@ func (vtgate *VTGateProxy) Dial(ctx context.Context, target string, opts ...grpc }, conf.GRPCDialOptions...) } - db, err := vtgate.DialFunc(conf) + db, err := vtgate.dialFunc(conf) if err != nil { return fmt.Errorf("error dialing vtgate: %w", err) } + log.Infof("Established gRPC connection to vtgate\n") + + vtgate.m.Lock() + defer vtgate.m.Unlock() + vtgate.conn = db vtgate.dialedAt = time.Now() @@ -207,18 +206,12 @@ func (vtgate *VTGateProxy) Close() error { vtgate.m.Lock() defer vtgate.m.Unlock() - return vtgate.closeLocked() -} - -func (vtgate *VTGateProxy) closeLocked() error { if vtgate.conn == nil { return nil } - err := vtgate.conn.Close() - vtgate.conn = nil - - return err + defer func() { vtgate.conn = nil }() + return vtgate.conn.Close() } // Debug implements debug.Debuggable for VTGateProxy. diff --git a/go/vt/vtadmin/vtsql/vtsql_test.go b/go/vt/vtadmin/vtsql/vtsql_test.go index 53b7ba03585..1f2465dec66 100644 --- a/go/vt/vtadmin/vtsql/vtsql_test.go +++ b/go/vt/vtadmin/vtsql/vtsql_test.go @@ -18,22 +18,15 @@ package vtsql import ( "context" - "database/sql" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/vt/callerid" "vitess.io/vitess/go/vt/grpcclient" - "vitess.io/vitess/go/vt/vitessdriver" - "vitess.io/vitess/go/vt/vtadmin/cluster/discovery/fakediscovery" - "vitess.io/vitess/go/vt/vtadmin/cluster/resolver" - "vitess.io/vitess/go/vt/vtadmin/vtsql/fakevtsql" querypb "vitess.io/vitess/go/vt/proto/query" - vtadminpb "vitess.io/vitess/go/vt/proto/vtadmin" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" ) @@ -87,91 +80,3 @@ func Test_getQueryContext(t *testing.T) { assertEffectiveCaller(t, callerid.EffectiveCallerIDFromContext(outctx), "efuser", "vtadmin", "") assertImmediateCaller(t, callerid.ImmediateCallerIDFromContext(outctx), "imuser") } - -func TestDial(t *testing.T) { - t.Helper() - - tests := []struct { - name string - disco *fakediscovery.Fake - gates []*vtadminpb.VTGate - proxy *VTGateProxy - dialer func(cfg vitessdriver.Configuration) (*sql.DB, error) - shouldErr bool - }{ - { - name: "existing conn", - proxy: &VTGateProxy{ - cluster: &vtadminpb.Cluster{}, - conn: sql.OpenDB(&fakevtsql.Connector{}), - }, - shouldErr: false, - }, - { - name: "dialer error", - disco: fakediscovery.New(), - gates: []*vtadminpb.VTGate{ - { - Hostname: "gate", - }, - }, - proxy: &VTGateProxy{ - cluster: &vtadminpb.Cluster{Id: "test"}, - DialFunc: func(cfg vitessdriver.Configuration) (*sql.DB, error) { - return nil, assert.AnError - }, - }, - shouldErr: true, - }, - { - name: "success", - disco: fakediscovery.New(), - gates: []*vtadminpb.VTGate{ - { - Hostname: "gate", - }, - }, - proxy: &VTGateProxy{ - cluster: &vtadminpb.Cluster{Id: "test"}, - creds: &StaticAuthCredentials{ - StaticAuthClientCreds: &grpcclient.StaticAuthClientCreds{ - Username: "user", - Password: "pass", - }, - }, - DialFunc: func(cfg vitessdriver.Configuration) (*sql.DB, error) { - return sql.OpenDB(&fakevtsql.Connector{}), nil - }, - }, - }, - } - - ctx := context.Background() - - for _, tt := range tests { - tt := tt - - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - if tt.disco != nil { - if len(tt.gates) > 0 { - tt.disco.AddTaggedGates(nil, tt.gates...) - } - } - - tt.proxy.resolver = (&resolver.Options{ - Discovery: tt.disco, - DiscoveryTimeout: 50 * time.Millisecond, - }).NewBuilder(tt.proxy.cluster.Id) - - err := tt.proxy.Dial(ctx, "") - if tt.shouldErr { - assert.Error(t, err) - return - } - - assert.NoError(t, err) - }) - } -} diff --git a/go/vt/vtgate/vtgateconn/vtgateconn.go b/go/vt/vtgate/vtgateconn/vtgateconn.go index 6483526aabb..cbef891bf0e 100644 --- a/go/vt/vtgate/vtgateconn/vtgateconn.go +++ b/go/vt/vtgate/vtgateconn/vtgateconn.go @@ -20,6 +20,7 @@ import ( "context" "flag" "fmt" + "sync" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" @@ -170,11 +171,17 @@ type Impl interface { // object that can communicate with a VTGate. type DialerFunc func(ctx context.Context, address string) (Impl, error) -var dialers = make(map[string]DialerFunc) +var ( + dialers = make(map[string]DialerFunc) + dialersM sync.Mutex +) // RegisterDialer is meant to be used by Dialer implementations // to self register. func RegisterDialer(name string, dialer DialerFunc) { + dialersM.Lock() + defer dialersM.Unlock() + if _, ok := dialers[name]; ok { log.Warningf("Dialer %s already exists, overwriting it", name) } @@ -183,7 +190,10 @@ func RegisterDialer(name string, dialer DialerFunc) { // DialProtocol dials a specific protocol, and returns the *VTGateConn func DialProtocol(ctx context.Context, protocol string, address string) (*VTGateConn, error) { + dialersM.Lock() dialer, ok := dialers[protocol] + dialersM.Unlock() + if !ok { return nil, fmt.Errorf("no dialer registered for VTGate protocol %s", protocol) }