diff --git a/go/vt/vitessdriver/driver.go b/go/vt/vitessdriver/driver.go
index 63c00ab730e..66f6e52528b 100644
--- a/go/vt/vitessdriver/driver.go
+++ b/go/vt/vitessdriver/driver.go
@@ -209,6 +209,18 @@ func (c *conn) dial() error {
 	return nil
 }
 
+// Ping checks that the connection is usable by executing "select 1" with the
+// given context and returning whatever error that execution produces.
+// Streaming connections are rejected with an error before any query is run.
+func (c *conn) Ping(ctx context.Context) error {
+	if c.Streaming {
+		return errors.New("Ping not allowed for streaming connections")
+	}
+
+	_, err := c.ExecContext(ctx, "select 1", nil)
+	return err
+}
+
 func (c *conn) Prepare(query string) (driver.Stmt, error) {
 	return &stmt{c: c, query: query}, nil
 }
diff --git a/go/vt/vtadmin/vtsql/config.go b/go/vt/vtadmin/vtsql/config.go
index 285925c5959..8655bfd7c9f 100644
--- a/go/vt/vtadmin/vtsql/config.go
+++ b/go/vt/vtadmin/vtsql/config.go
@@ -18,6 +18,7 @@ package vtsql
 
 import (
 	"fmt"
+	"time"
 
 	"github.com/spf13/pflag"
 
@@ -34,6 +35,10 @@ type Config struct {
 	DiscoveryTags []string
 	Credentials   Credentials
 
+	// DialPingTimeout is the timeout applied when Dial pings an existing
+	// connection. A non-positive value applies no additional deadline.
+	DialPingTimeout time.Duration
+
 	// CredentialsPath is used only to power vtadmin debug endpoints; there may
 	// be a better way where we don't need to put this in the config, because
 	// it's not really an "option" in normal use.
@@ -65,6 +70,8 @@ func Parse(cluster *vtadminpb.Cluster, disco discovery.Discovery, args []string)
 func (c *Config) Parse(args []string) error {
 	fs := pflag.NewFlagSet("", pflag.ContinueOnError)
 
+	fs.DurationVar(&c.DialPingTimeout, "dial-ping-timeout", time.Millisecond*500,
+		"Timeout to use when pinging an existing connection during calls to Dial.")
 	fs.StringSliceVar(&c.DiscoveryTags, "discovery-tags", []string{},
 		"repeated, comma-separated list of tags to use when discovering a vtgate to connect to. "+
 			"the semantics of the tags may depend on the specific discovery implementation used")
diff --git a/go/vt/vtadmin/vtsql/config_test.go b/go/vt/vtadmin/vtsql/config_test.go
index e32b6855a4f..84f877f234a 100644
--- a/go/vt/vtadmin/vtsql/config_test.go
+++ b/go/vt/vtadmin/vtsql/config_test.go
@@ -23,6 +23,7 @@ import (
 	"path/filepath"
 	"strings"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -143,6 +144,7 @@ func TestConfigParse(t *testing.T) {
 				Id:   "cid",
 				Name: "testcluster",
 			},
+			DialPingTimeout: time.Millisecond * 500,
 			DiscoveryTags:   expectedTags,
 			Credentials:     expectedCreds,
 			CredentialsPath: path,
diff --git a/go/vt/vtadmin/vtsql/vtsql.go b/go/vt/vtadmin/vtsql/vtsql.go
index 6a1b9b3c82c..c5fe327ec49 100644
--- a/go/vt/vtadmin/vtsql/vtsql.go
+++ b/go/vt/vtadmin/vtsql/vtsql.go
@@ -21,11 +21,13 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
+	"time"
 
 	"google.golang.org/grpc"
 
 	"vitess.io/vitess/go/trace"
 	"vitess.io/vitess/go/vt/callerid"
+	"vitess.io/vitess/go/vt/log"
 	"vitess.io/vitess/go/vt/vitessdriver"
 	"vitess.io/vitess/go/vt/vtadmin/cluster/discovery"
 	"vitess.io/vitess/go/vt/vtadmin/vtadminproto"
@@ -71,7 +73,8 @@ type VTGateProxy struct {
 	// DialFunc is called to open a new database connection. In production this
 	// should always be vitessdriver.OpenWithConfiguration, but it is exported
 	// for testing purposes.
-	DialFunc func(cfg vitessdriver.Configuration) (*sql.DB, error)
+	DialFunc        func(cfg vitessdriver.Configuration) (*sql.DB, error)
+	dialPingTimeout time.Duration
 
 	host string
 	conn *sql.DB
@@ -96,11 +99,12 @@ func New(cfg *Config) *VTGateProxy {
 	}
 
 	return &VTGateProxy{
-		cluster:       cfg.Cluster,
-		discovery:     cfg.Discovery,
-		discoveryTags: discoveryTags,
-		creds:         cfg.Credentials,
-		DialFunc:      vitessdriver.OpenWithConfiguration,
+		cluster:         cfg.Cluster,
+		discovery:       cfg.Discovery,
+		discoveryTags:   discoveryTags,
+		creds:           cfg.Credentials,
+		DialFunc:        vitessdriver.OpenWithConfiguration,
+		dialPingTimeout: cfg.DialPingTimeout,
 	}
 }
 
@@ -132,13 +136,33 @@ func (vtgate *VTGateProxy) Dial(ctx context.Context, target string, opts ...grpc
 	vtgate.annotateSpan(span)
 
 	if vtgate.conn != nil {
-		span.Annotate("is_noop", true)
-
-		// (TODO:@amason): consider a quick Ping() check in this case, and get a
-		// new connection if that fails.
-		return nil
+		// Bound only the health-check ping. A non-positive timeout would give
+		// an already-expired context and make every ping fail, so in that
+		// case the caller's context is used as-is.
+		pingCtx := ctx
+		if vtgate.dialPingTimeout > 0 {
+			var cancel context.CancelFunc
+			pingCtx, cancel = context.WithTimeout(ctx, vtgate.dialPingTimeout)
+			defer cancel()
+		}
+
+		err := vtgate.PingContext(pingCtx)
+		if err == nil {
+			log.Infof("Have valid connection to %s, reusing it.", vtgate.host)
+			span.Annotate("is_noop", true)
+
+			return nil
+		}
+
+		log.Warningf("Ping failed on host %s: %s; Rediscovering a vtgate to get new connection", vtgate.host, err)
+
+		if err := vtgate.Close(); err != nil {
+			log.Warningf("Error when closing connection to vtgate %s: %s; Continuing anyway ...", vtgate.host, err)
+		}
 	}
 
+	span.Annotate("is_noop", false)
+
 	if vtgate.host == "" {
 		gate, err := vtgate.discovery.DiscoverVTGateAddr(ctx, vtgate.discoveryTags)
 		if err != nil {
@@ -150,6 +174,8 @@ func (vtgate *VTGateProxy) Dial(ctx context.Context, target string, opts ...grpc
 		span.Annotate("vtgate_host", gate)
 	}
 
+	log.Infof("Dialing %s ...", vtgate.host)
+
 	conf := vitessdriver.Configuration{
 		Protocol: fmt.Sprintf("grpc_%s", vtgate.cluster.Id),
 		Address:  vtgate.host,
diff --git a/go/vt/vtadmin/vtsql/vtsql_test.go b/go/vt/vtadmin/vtsql/vtsql_test.go
index 17c1b503398..c0ddd84e170 100644
--- a/go/vt/vtadmin/vtsql/vtsql_test.go
+++ b/go/vt/vtadmin/vtsql/vtsql_test.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"database/sql"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
@@ -100,8 +101,9 @@ func TestDial(t *testing.T) {
 		{
 			name: "existing conn",
 			proxy: &VTGateProxy{
-				cluster: &vtadminpb.Cluster{},
-				conn:    sql.OpenDB(&fakevtsql.Connector{}),
+				cluster:         &vtadminpb.Cluster{},
+				conn:            sql.OpenDB(&fakevtsql.Connector{}),
+				dialPingTimeout: time.Millisecond * 10,
 			},
 			shouldErr: false,
 		},