Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type DB struct {
name string
// isConnFail trigger a panic in the connection handler.
isConnFail bool
// connDelay causes a sleep in the connection handler
connDelay time.Duration
// shouldClose, if true, tells ComQuery() to close the connection when
// processing the next query. This will trigger a MySQL client error with
// errno 2013 ("server lost").
Expand Down Expand Up @@ -288,6 +290,10 @@ func (db *DB) NewConnection(c *mysql.Conn) {
panic(fmt.Errorf("simulating a connection failure"))
}

if db.connDelay != 0 {
time.Sleep(db.connDelay)
}

if conn, ok := db.connections[c.ConnectionID]; ok {
db.t.Fatalf("BUG: connection with id: %v is already active. existing conn: %v new conn: %v", c.ConnectionID, conn, c)
}
Expand Down Expand Up @@ -517,6 +523,13 @@ func (db *DB) DisableConnFail() {
db.isConnFail = false
}

// SetConnDelay delays connections to this fake DB for the given duration
func (db *DB) SetConnDelay(d time.Duration) {
db.mu.Lock()
defer db.mu.Unlock()
db.connDelay = d
}

// EnableShouldClose closes the connection when processing the next query.
func (db *DB) EnableShouldClose() {
db.mu.Lock()
Expand Down
5 changes: 5 additions & 0 deletions go/vt/dbconnpool/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,16 @@ func (dbc *DBConnection) ExecuteStreamFetch(query string, callback func(*sqltype
// NewDBConnection returns a new DBConnection based on the ConnParams
// and will use the provided stats to collect timing.
func NewDBConnection(info *mysql.ConnParams, mysqlStats *stats.Timings) (*DBConnection, error) {
start := time.Now()
defer mysqlStats.Record("Connect", start)
params, err := dbconfigs.WithCredentials(info)
if err != nil {
return nil, err
}
ctx := context.Background()
c, err := mysql.Connect(ctx, &params)
if err != nil {
mysqlStats.Record("ConnectError", start)
}
return &DBConnection{c, mysqlStats}, err
}
4 changes: 2 additions & 2 deletions go/vt/vtqueryserver/endtoend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,12 @@ func testFetch(t *testing.T, conn *mysql.Conn, sql string, expectedRows int) *sq
func testDML(t *testing.T, conn *mysql.Conn, sql string, expectedNumQueries int64, expectedRowsAffected uint64) {
t.Helper()

numQueries := tabletenv.MySQLStats.Count()
numQueries := tabletenv.MySQLStats.Counts()["Exec"]
result, err := conn.ExecuteFetch(sql, 1000, false)
if err != nil {
t.Errorf("error: %v", err)
}
numQueries = tabletenv.MySQLStats.Count() - numQueries
numQueries = tabletenv.MySQLStats.Counts()["Exec"] - numQueries

if numQueries != expectedNumQueries {
t.Errorf("expected %d mysql queries but got %d", expectedNumQueries, numQueries)
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vttablet/tabletserver/connpool/dbconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ func (dbc *DBConn) execOnce(ctx context.Context, query string, maxrows int, want
dbc.current.Set(query)
defer dbc.current.Set("")

// Check if the context is already past its deadline before
// trying to execute the query.
select {
case <-ctx.Done():
return nil, fmt.Errorf("%v before execution started", ctx.Err())
default:
}

done, wg := dbc.setDeadline(ctx)
if done != nil {
defer func() {
Expand Down
116 changes: 115 additions & 1 deletion go/vt/vttablet/tabletserver/connpool/dbconn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,23 @@ import (
"vitess.io/vitess/go/sqltypes"

querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv"
)

func compareTimingCounts(t *testing.T, op string, delta int64, before, after map[string]int64) {
t.Helper()
countBefore := before[op]
countAfter := after[op]
if countAfter-countBefore != delta {
t.Errorf("Expected %s to increase by %d, got %d (%d => %d)", op, delta, countAfter-countBefore, countBefore, countAfter)
}
}

func TestDBConnExec(t *testing.T) {
db := fakesqldb.New(t)
defer db.Close()
startCounts := tabletenv.MySQLStats.Counts()

sql := "select * from test_table limit 1000"
expectedResult := &sqltypes.Result{
Fields: []*querypb.Field{
Expand Down Expand Up @@ -66,7 +78,13 @@ func TestDBConnExec(t *testing.T) {
if !reflect.DeepEqual(expectedResult, result) {
t.Errorf("Exec: %v, want %v", expectedResult, result)
}
// Exec fail

compareTimingCounts(t, "Connect", 1, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 1, startCounts, tabletenv.MySQLStats.Counts())

startCounts = tabletenv.MySQLStats.Counts()

// Exec fail due to client side error
db.AddRejectedQuery(sql, &mysql.SQLError{
Num: 2012,
Message: "connection fail",
Expand All @@ -77,6 +95,102 @@ func TestDBConnExec(t *testing.T) {
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("Exec: %v, want %s", err, want)
}

// The client side error triggers a retry in exec.
compareTimingCounts(t, "Connect", 1, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 2, startCounts, tabletenv.MySQLStats.Counts())

startCounts = tabletenv.MySQLStats.Counts()

// Set the connection fail flag and and try again.
// This time the initial query fails as does the reconnect attempt.
db.EnableConnFail()
_, err = dbConn.Exec(ctx, sql, 1, false)
want = "packet read failed"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("Exec: %v, want %s", err, want)
}
db.DisableConnFail()

compareTimingCounts(t, "Connect", 1, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "ConnectError", 1, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 1, startCounts, tabletenv.MySQLStats.Counts())
}

func TestDBConnDeadline(t *testing.T) {
db := fakesqldb.New(t)
defer db.Close()
startCounts := tabletenv.MySQLStats.Counts()
sql := "select * from test_table limit 1000"
expectedResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Type: sqltypes.VarChar},
},
RowsAffected: 1,
Rows: [][]sqltypes.Value{
{sqltypes.NewVarChar("123")},
},
}
db.AddQuery(sql, expectedResult)

connPool := newPool()
connPool.Open(db.ConnParams(), db.ConnParams(), db.ConnParams())
defer connPool.Close()

db.SetConnDelay(100 * time.Millisecond)
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(50*time.Millisecond))
defer cancel()

dbConn, err := NewDBConn(connPool, db.ConnParams())
defer dbConn.Close()
if err != nil {
t.Fatalf("should not get an error, err: %v", err)
}

_, err = dbConn.Exec(ctx, sql, 1, false)
want := "context deadline exceeded before execution started"
if err == nil || !strings.Contains(err.Error(), want) {
t.Errorf("Exec: %v, want %s", err, want)
}

compareTimingCounts(t, "Connect", 1, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "ConnectError", 0, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 0, startCounts, tabletenv.MySQLStats.Counts())

startCounts = tabletenv.MySQLStats.Counts()

ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(10*time.Second))
defer cancel()

result, err := dbConn.Exec(ctx, sql, 1, false)
if err != nil {
t.Fatalf("should not get an error, err: %v", err)
}
expectedResult.Fields = nil
if !reflect.DeepEqual(expectedResult, result) {
t.Errorf("Exec: %v, want %v", expectedResult, result)
}

compareTimingCounts(t, "Connect", 0, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "ConnectError", 0, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 1, startCounts, tabletenv.MySQLStats.Counts())

startCounts = tabletenv.MySQLStats.Counts()

// Test with just the background context (with no deadline)
result, err = dbConn.Exec(context.Background(), sql, 1, false)
if err != nil {
t.Fatalf("should not get an error, err: %v", err)
}
expectedResult.Fields = nil
if !reflect.DeepEqual(expectedResult, result) {
t.Errorf("Exec: %v, want %v", expectedResult, result)
}

compareTimingCounts(t, "Connect", 0, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "ConnectError", 0, startCounts, tabletenv.MySQLStats.Counts())
compareTimingCounts(t, "Exec", 1, startCounts, tabletenv.MySQLStats.Counts())

}

func TestDBConnKill(t *testing.T) {
Expand Down