Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ type connectResult struct {
// FIXME(alainjobart) once we have more of a server side, add test cases
// to cover all failure scenarios.
func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
if params.ConnectTimeoutMs != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
}
netProto := "tcp"
addr := ""
if params.UnixSocket != "" {
Expand Down
10 changes: 10 additions & 0 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ func TestConnectTimeout(t *testing.T) {
t.Errorf("Was expecting context.DeadlineExceeded but got: %v", err)
}

// Tests a connection timeout through params
ctx = context.Background()
paramsWithTimeout := params
paramsWithTimeout.ConnectTimeoutMs = 1
_, err = Connect(ctx, paramsWithTimeout)
cancel()
if err != context.DeadlineExceeded {
t.Errorf("Was expecting context.DeadlineExceeded but got: %v", err)
}

// Now the server will listen, but close all connections on accept.
wg := sync.WaitGroup{}
wg.Add(1)
Expand Down
6 changes: 3 additions & 3 deletions go/pools/resource_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ var (
)

// Factory is a function that can be used to create a resource.
type Factory func() (Resource, error)
type Factory func(context.Context) (Resource, error)

// Resource defines the interface that every resource must provide.
// Thread synchronization between Close() and IsClosed()
Expand Down Expand Up @@ -228,7 +228,7 @@ func (rp *ResourcePool) get(ctx context.Context) (resource Resource, err error)
// Unwrap
if wrapper.resource == nil {
span, _ := trace.NewSpan(ctx, "ResourcePool.factory")
wrapper.resource, err = rp.factory()
wrapper.resource, err = rp.factory(ctx)
span.Finish()
if err != nil {
rp.resources <- resourceWrapper{}
Expand Down Expand Up @@ -267,7 +267,7 @@ func (rp *ResourcePool) Put(resource Resource) {
}

func (rp *ResourcePool) reopenResource(wrapper *resourceWrapper) {
if r, err := rp.factory(); err == nil {
if r, err := rp.factory(context.TODO()); err == nil {
wrapper.resource = r
wrapper.timeUsed = time.Now()
} else {
Expand Down
6 changes: 3 additions & 3 deletions go/pools/resource_pool_flaky_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@ func logWait(start time.Time) {
waitStarts = append(waitStarts, start)
}

func PoolFactory() (Resource, error) {
func PoolFactory(ctx context.Context) (Resource, error) {
count.Add(1)
return &TestResource{lastID.Add(1), false}, nil
}

func FailFactory() (Resource, error) {
func FailFactory(ctx context.Context) (Resource, error) {
return nil, errors.New("Failed")
}

func SlowFailFactory() (Resource, error) {
func SlowFailFactory(ctx context.Context) (Resource, error) {
time.Sleep(10 * time.Millisecond)
return nil, errors.New("Failed")
}
Expand Down
13 changes: 1 addition & 12 deletions go/vt/dbconnpool/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package dbconnpool
import (
"context"
"fmt"
"time"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/sqltypes"
Expand All @@ -35,17 +34,7 @@ type DBConnection struct {

// NewDBConnection returns a new DBConnection based on the ConnParams
// and will use the provided stats to collect timing.
func NewDBConnection(info dbconfigs.Connector) (*DBConnection, error) {
ctx := context.Background()
params, err := info.MysqlParams()
if err != nil {
return nil, err
}
if params.ConnectTimeoutMs != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(params.ConnectTimeoutMs)*time.Millisecond)
defer cancel()
}
func NewDBConnection(ctx context.Context, info dbconfigs.Connector) (*DBConnection, error) {
c, err := info.Connect(ctx)
if err != nil {
return nil, err
Expand Down
6 changes: 3 additions & 3 deletions go/vt/dbconnpool/connection_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ func (cp *ConnectionPool) Open(info dbconfigs.Connector) {
}

// connect is used by the resource pool to create a new Resource.
func (cp *ConnectionPool) connect() (pools.Resource, error) {
c, err := NewDBConnection(cp.info)
func (cp *ConnectionPool) connect(ctx context.Context) (pools.Resource, error) {
c, err := NewDBConnection(ctx, cp.info)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -216,7 +216,7 @@ func (cp *ConnectionPool) Get(ctx context.Context) (*PooledDBConnection, error)
if cp.resolutionFrequency > 0 &&
cp.hostIsNotIP &&
!cp.validAddress(net.ParseIP(r.(*PooledDBConnection).RemoteAddr().String())) {
err := r.(*PooledDBConnection).Reconnect()
err := r.(*PooledDBConnection).Reconnect(ctx)
if err != nil {
p.Put(r)
return nil, err
Expand Down
6 changes: 4 additions & 2 deletions go/vt/dbconnpool/pooled_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.

package dbconnpool

import "context"

// PooledDBConnection re-exposes DBConnection to be used by ConnectionPool.
type PooledDBConnection struct {
*DBConnection
Expand All @@ -33,9 +35,9 @@ func (pc *PooledDBConnection) Recycle() {

// Reconnect replaces the existing underlying connection with a new one,
// if possible. Recycle should still be called afterwards.
func (pc *PooledDBConnection) Reconnect() error {
func (pc *PooledDBConnection) Reconnect(ctx context.Context) error {
pc.DBConnection.Close()
newConn, err := NewDBConnection(pc.pool.info)
newConn, err := NewDBConnection(ctx, pc.pool.info)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/mysqlctl/fakemysqldaemon/fakemysqldaemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -472,12 +472,12 @@ func (fmd *FakeMysqlDaemon) GetAppConnection(ctx context.Context) (*dbconnpool.P

// GetDbaConnection is part of the MysqlDaemon interface.
func (fmd *FakeMysqlDaemon) GetDbaConnection() (*dbconnpool.DBConnection, error) {
return dbconnpool.NewDBConnection(fmd.db.ConnParams())
return dbconnpool.NewDBConnection(context.Background(), fmd.db.ConnParams())
}

// GetAllPrivsConnection is part of the MysqlDaemon interface.
func (fmd *FakeMysqlDaemon) GetAllPrivsConnection() (*dbconnpool.DBConnection, error) {
return dbconnpool.NewDBConnection(fmd.db.ConnParams())
return dbconnpool.NewDBConnection(context.Background(), fmd.db.ConnParams())
}

// SetSemiSyncEnabled is part of the MysqlDaemon interface.
Expand Down
4 changes: 2 additions & 2 deletions go/vt/mysqlctl/mysqld.go
Original file line number Diff line number Diff line change
Expand Up @@ -1097,12 +1097,12 @@ func (mysqld *Mysqld) GetAppConnection(ctx context.Context) (*dbconnpool.PooledD

// GetDbaConnection creates a new DBConnection.
func (mysqld *Mysqld) GetDbaConnection() (*dbconnpool.DBConnection, error) {
return dbconnpool.NewDBConnection(mysqld.dbcfgs.Dba())
return dbconnpool.NewDBConnection(context.TODO(), mysqld.dbcfgs.Dba())
}

// GetAllPrivsConnection creates a new DBConnection.
func (mysqld *Mysqld) GetAllPrivsConnection() (*dbconnpool.DBConnection, error) {
return dbconnpool.NewDBConnection(mysqld.dbcfgs.AllPrivsWithDB())
return dbconnpool.NewDBConnection(context.TODO(), mysqld.dbcfgs.AllPrivsWithDB())
}

// Close will close this instance of Mysqld. It will wait for all dba
Expand Down
2 changes: 1 addition & 1 deletion go/vt/mysqlctl/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func getPoolReconnect(ctx context.Context, pool *dbconnpool.ConnectionPool) (*db
if _, err := conn.ExecuteFetch("SELECT 1", 1, false); err != nil {
// If we get a connection error, try to reconnect.
if sqlErr, ok := err.(*mysql.SQLError); ok && (sqlErr.Number() == mysql.CRServerGone || sqlErr.Number() == mysql.CRServerLost) {
if err := conn.Reconnect(); err != nil {
if err := conn.Reconnect(ctx); err != nil {
conn.Recycle()
return nil, err
}
Expand Down
69 changes: 0 additions & 69 deletions go/vt/vttablet/endtoend/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ func TestConfigVars(t *testing.T) {
}, {
tag: "QueryTimeout",
val: int(tabletenv.Config.QueryTimeout * 1e9),
}, {
tag: "QueryPoolTimeout",
val: int(tabletenv.Config.QueryPoolTimeout * 1e9),
}, {
tag: "SchemaReloadTime",
val: int(tabletenv.Config.SchemaReloadTime * 1e9),
Expand Down Expand Up @@ -111,9 +108,6 @@ func TestConfigVars(t *testing.T) {
}, {
tag: "TransactionPoolMaxCap",
val: tabletenv.Config.TransactionCap,
}, {
tag: "TransactionPoolTimeout",
val: int(tabletenv.Config.TxPoolTimeout * 1e9),
}, {
tag: "TransactionTimeout",
val: int(tabletenv.Config.TransactionTimeout * 1e9),
Expand Down Expand Up @@ -388,69 +382,6 @@ func TestQueryTimeout(t *testing.T) {
}
}

func TestQueryPoolTimeout(t *testing.T) {
vstart := framework.DebugVars()

defer framework.Server.SetQueryPoolTimeout(framework.Server.GetQueryPoolTimeout())
framework.Server.SetQueryPoolTimeout(100 * time.Millisecond)
defer framework.Server.SetPoolSize(framework.Server.PoolSize())
framework.Server.SetPoolSize(1)

client := framework.NewClient()

ch := make(chan error)
go func() {
_, qerr := framework.NewClient().Execute("select sleep(0.5) from dual", nil)
ch <- qerr
}()
// The queries have to be different so consolidator doesn't kick in.
go func() {
_, qerr := framework.NewClient().Execute("select sleep(0.49) from dual", nil)
ch <- qerr
}()

err1 := <-ch
err2 := <-ch

if err1 == nil && err2 == nil {
t.Errorf("both queries unexpectedly succeeded")
}
if err1 != nil && err2 != nil {
t.Errorf("both queries unexpectedly failed")
}

var err error
if err1 != nil {
err = err1
} else {
err = err2
}

if code := vterrors.Code(err); code != vtrpcpb.Code_RESOURCE_EXHAUSTED {
t.Errorf("Error code: %v, want %v", code, vtrpcpb.Code_RESOURCE_EXHAUSTED)
}

// Test that this doesn't override the query timeout
defer framework.Server.QueryTimeout.Set(framework.Server.QueryTimeout.Get())
framework.Server.QueryTimeout.Set(100 * time.Millisecond)

_, err = client.Execute("select sleep(1) from vitess_test", nil)
if code := vterrors.Code(err); code != vtrpcpb.Code_CANCELED {
t.Errorf("Error code: %v, want %v", code, vtrpcpb.Code_CANCELED)
}

vend := framework.DebugVars()
if err := verifyIntValue(vend, "QueryPoolTimeout", int(100*time.Millisecond)); err != nil {
t.Error(err)
}
if err := verifyIntValue(vend, "QueryTimeout", int(100*time.Millisecond)); err != nil {
t.Error(err)
}
if err := compareIntDiff(vend, "Kills/Queries", vstart, 1); err != nil {
t.Error(err)
}
}

func TestConnPoolWaitCap(t *testing.T) {
vstart := framework.DebugVars()

Expand Down
88 changes: 0 additions & 88 deletions go/vt/vttablet/endtoend/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"testing"
"time"

Expand All @@ -29,14 +28,12 @@ import (
"golang.org/x/net/context"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vttablet/endtoend/framework"
"vitess.io/vitess/go/vt/vttablet/tabletserver"
"vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv"

querypb "vitess.io/vitess/go/vt/proto/query"
topodatapb "vitess.io/vitess/go/vt/proto/topodata"
vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
)

func TestCommit(t *testing.T) {
Expand Down Expand Up @@ -339,50 +336,6 @@ func TestTxPoolSize(t *testing.T) {
}
}

func TestTxTimeout(t *testing.T) {
vstart := framework.DebugVars()

defer framework.Server.SetTxTimeout(framework.Server.TxTimeout())
txTimeout := 1 * time.Millisecond
framework.Server.SetTxTimeout(txTimeout)
if err := verifyIntValue(framework.DebugVars(), "TransactionTimeout", int(txTimeout)); err != nil {
t.Error(err)
}

defer framework.Server.SetTxPoolTimeout(framework.Server.TxPoolTimeout())
txPoolTimeout := 2 * time.Millisecond
framework.Server.SetTxPoolTimeout(txPoolTimeout)
if err := verifyIntValue(framework.DebugVars(), "TransactionPoolTimeout", int(txPoolTimeout)); err != nil {
t.Error(err)
}

catcher := framework.NewTxCatcher()
defer catcher.Close()
client := framework.NewClient()
err := client.Begin(false)
if err != nil {
t.Error(err)
return
}
tx, err := catcher.Next()
if err != nil {
t.Error(err)
return
}
if tx.Conclusion != "kill" {
t.Errorf("Conclusion: %s, want kill", tx.Conclusion)
}
if err := compareIntDiff(framework.DebugVars(), "Kills/Transactions", vstart, 1); err != nil {
t.Error(err)
}

// Ensure commit fails.
err = client.Commit()
if code := vterrors.Code(err); code != vtrpcpb.Code_ABORTED {
t.Errorf("Commit code: %v, want %v", code, vtrpcpb.Code_ABORTED)
}
}

func TestForUpdate(t *testing.T) {
for _, mode := range []string{"for update", "lock in share mode"} {
client := framework.NewClient()
Expand Down Expand Up @@ -813,44 +766,3 @@ func TestManualTwopcz(t *testing.T) {
fmt.Print("Sleeping for 30 seconds\n")
time.Sleep(30 * time.Second)
}

func TestTransactionPoolResourceWaitTime(t *testing.T) {
defer framework.Server.SetPoolSize(framework.Server.TxPoolSize())
defer framework.Server.SetTxPoolTimeout(framework.Server.TxPoolTimeout())
framework.Server.SetTxPoolSize(1)
framework.Server.SetTxPoolTimeout(10 * time.Second)
debugVarPath := "Waits/Histograms/TransactionPoolResourceWaitTime/Count"

for sleep := 0.1; sleep < 10.0; sleep *= 2 {
vstart := framework.DebugVars()
var wg sync.WaitGroup
wg.Add(2)

transactionFunc := func() {
client := framework.NewClient()

bv := map[string]*querypb.BindVariable{}
query := fmt.Sprintf("select sleep(%v) from dual", sleep)
if _, err := client.BeginExecute(query, bv); err != nil {
t.Error(err)
return
}
if err := client.Rollback(); err != nil {
t.Error(err)
return
}
wg.Done()
}
go transactionFunc()
go transactionFunc()
wg.Wait()
vend := framework.DebugVars()
if err := compareIntDiff(vend, debugVarPath, vstart, 1); err != nil {
t.Logf("DebugVars %v not incremented with sleep=%v", debugVarPath, sleep)
continue
}
t.Logf("DebugVars %v properly incremented with sleep=%v", debugVarPath, sleep)
return
}
t.Errorf("DebugVars %v not incremented", debugVarPath)
}
Loading