Skip to content

Commit 0d69f22

Browse files
committed
feat: support statement_timeout and transaction_timeout property
Add a statement_timeout connection property that is used as the default timeout for all statements that are executed on a connection. If both a context timeout and a statement_timeout is specified, then the lower of the two will be used. Also adds a transaction_timeout property that is additionally used for all statements in a read/write transaction. The deadline of the transaction is calculated at the start of the transaction, and all statements in the transaction get this deadline, unless the statement already has an earlier deadline from for example a statement_timeout or a context deadline. Fixes #574 Fixes #575
1 parent ad05fde commit 0d69f22

14 files changed

+393
-30
lines changed

conn.go

Lines changed: 97 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,51 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831831
return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil
832832
}
833833

834+
func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) {
835+
var statementDeadline time.Time
836+
var transactionDeadline time.Time
837+
var deadline time.Time
838+
var hasStatementDeadline bool
839+
var hasTransactionDeadline bool
840+
841+
statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state)
842+
if statementTimeout != time.Duration(0) {
843+
hasStatementDeadline = true
844+
statementDeadline = time.Now().Add(statementTimeout)
845+
}
846+
transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline()
847+
if err != nil {
848+
return nil, nil, err
849+
}
850+
if !hasStatementDeadline && !hasTransactionDeadline {
851+
return ctx, func() {}, nil
852+
}
853+
if hasTransactionDeadline && hasStatementDeadline {
854+
if statementDeadline.Before(transactionDeadline) {
855+
deadline = statementDeadline
856+
} else {
857+
deadline = transactionDeadline
858+
}
859+
} else if hasStatementDeadline {
860+
deadline = statementDeadline
861+
} else {
862+
deadline = transactionDeadline
863+
}
864+
newCtx, cancel := context.WithDeadline(ctx, deadline)
865+
return newCtx, cancel, nil
866+
}
867+
868+
func (c *conn) transactionDeadline() (time.Time, bool, error) {
869+
if c.tx == nil {
870+
return time.Time{}, false, nil
871+
}
872+
if err := c.tx.ensureActivated(); err != nil {
873+
return time.Time{}, false, err
874+
}
875+
deadline, hasDeadline := c.tx.deadline()
876+
return deadline, hasDeadline, nil
877+
}
878+
834879
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
835880
// Execute client side statement if it is one.
836881
clientStmt, err := c.parser.ParseClientSideStatement(query)
@@ -849,13 +894,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849894
return c.queryContext(ctx, query, execOptions, args)
850895
}
851896

852-
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
897+
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
898+
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
899+
if err != nil {
900+
return nil, err
901+
}
902+
defer func() {
903+
if returnedErr != nil {
904+
cancel()
905+
}
906+
}()
853907
// Clear the commit timestamp of this connection before we execute the query.
854908
c.clearCommitResponse()
855909
// Check if the execution options contains an instruction to execute
856910
// a specific partition of a PartitionedQuery.
857911
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
858-
return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
912+
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
859913
}
860914

861915
stmt, err := prepareSpannerStmt(c.parser, query, args)
@@ -869,7 +923,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869923
if err != nil {
870924
return nil, err
871925
}
872-
return createDriverResultRows(res, execOptions), nil
926+
return createDriverResultRows(res, cancel, execOptions), nil
873927
}
874928
var iter rowIterator
875929
if c.tx == nil {
@@ -884,7 +938,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884938
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
885939
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
886940
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
887-
return c.executeAutoPartitionedQuery(ctx, query, execOptions, args)
941+
return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args)
888942
} else {
889943
// The statement was either detected as being a query, or potentially not recognized at all.
890944
// In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,14 +947,15 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893947
}
894948
} else {
895949
if execOptions.PartitionedQueryOptions.PartitionQuery {
950+
defer cancel()
896951
return c.tx.partitionQuery(ctx, stmt, execOptions)
897952
}
898953
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
899954
if err != nil {
900955
return nil, err
901956
}
902957
}
903-
res := createRows(iter, execOptions)
958+
res := createRows(iter, cancel, execOptions)
904959
if execOptions.DirectExecuteQuery {
905960
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906961
res.getColumns()
@@ -929,7 +984,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
929984
return c.execContext(ctx, query, execOptions, args)
930985
}
931986

932-
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) {
987+
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) {
988+
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
989+
if err != nil {
990+
return nil, err
991+
}
992+
defer func() {
993+
if returnedErr != nil {
994+
cancel()
995+
}
996+
}()
933997
// Clear the commit timestamp of this connection before we execute the statement.
934998
c.clearCommitResponse()
935999

@@ -1041,6 +1105,17 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411105
return noTransaction()
10421106
}
10431107
c.tx = c.prevTx
1108+
// If the aborted error happened during the Commit, then the transaction context has been cancelled,
1109+
// and we need to create a new one.
1110+
if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok {
1111+
newCtx, cancel := c.addTransactionTimeout(c.tx.ctx)
1112+
rwTx.ctx = newCtx
1113+
origClose := rwTx.close
1114+
rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1115+
origClose(result, commitResponse, commitErr)
1116+
cancel()
1117+
}
1118+
}
10441119
c.resetForRetry = true
10451120
} else if c.tx == nil {
10461121
return noTransaction()
@@ -1248,6 +1323,14 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481323
return c.tx, nil
12491324
}
12501325

1326+
func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
1327+
timeout := propertyTransactionTimeout.GetValueOrDefault(c.state)
1328+
if timeout == time.Duration(0) {
1329+
return ctx, func() {}
1330+
}
1331+
return context.WithTimeout(ctx, timeout)
1332+
}
1333+
12511334
func (c *conn) activateTransaction() (contextTransaction, error) {
12521335
closeFunc := c.tx.close
12531336
if propertyTransactionReadOnly.GetValueOrDefault(c.state) {
@@ -1283,19 +1366,21 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831366
opts := spanner.TransactionOptions{}
12841367
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))
12851368

1286-
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions {
1369+
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
1370+
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
12871371
defer func() {
12881372
// Reset the transaction_tag after starting the transaction.
12891373
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
12901374
}()
12911375
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
12921376
})
12931377
if err != nil {
1378+
cancel()
12941379
return nil, err
12951380
}
12961381
logger := c.logger.With("tx", "rw")
12971382
return &readWriteTransaction{
1298-
ctx: c.tx.ctx,
1383+
ctx: ctx,
12991384
conn: c,
13001385
logger: logger,
13011386
rwTx: tx,
@@ -1307,6 +1392,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071392
} else {
13081393
closeFunc(txResultRollback)
13091394
}
1395+
cancel()
13101396
},
13111397
retryAborts: sync.OnceValue(func() bool {
13121398
return c.RetryAbortsInternally()
@@ -1371,18 +1457,20 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711457
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
13721458
}
13731459

1374-
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
1460+
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
13751461
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))})
13761462
if err != nil {
13771463
return nil, err
13781464
}
13791465
r, err := c.queryContext(ctx, query, execOptions, args)
13801466
if err != nil {
13811467
_ = tx.Rollback()
1468+
cancel()
13821469
return nil, err
13831470
}
13841471
if rows, ok := r.(*rows); ok {
13851472
rows.close = func() error {
1473+
defer cancel()
13861474
return tx.Commit()
13871475
}
13881476
}

connection_leak_test.go

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,23 @@ import (
2424

2525
"cloud.google.com/go/spanner"
2626
"github.com/googleapis/go-sql-spanner/testutil"
27+
"go.uber.org/goleak"
2728
"google.golang.org/grpc/codes"
2829
gstatus "google.golang.org/grpc/status"
2930
)
3031

3132
func TestNoLeak(t *testing.T) {
32-
t.Parallel()
33+
// Not parallel, as it checks for leaked goroutines.
3334

34-
db, server, teardown := setupTestDBConnection(t)
35+
db, server, teardown := setupTestDBConnectionWithParams(t, "statement_timeout=10s;transaction_timeout=20s")
3536
defer teardown()
3637
// Set MaxOpenConns to 1 to force an error if anything leaks a connection.
3738
db.SetMaxOpenConns(1)
3839

3940
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
4041
defer cancel()
4142

42-
for i := 0; i < 2; i++ {
43+
runTests := func() {
4344
pingContext(ctx, t, db)
4445
pingFailed(ctx, t, server, db)
4546
simpleQuery(ctx, t, db)
@@ -50,8 +51,28 @@ func TestNoLeak(t *testing.T) {
5051
readOnlyTxWithStaleness(ctx, t, db)
5152
simpleReadWriteTx(ctx, t, db)
5253
runTransactionRetry(ctx, t, server, db)
54+
runTransactionRetryAbortedHalfway(ctx, t, server, db)
5355
readOnlyTxWithOptions(ctx, t, db)
5456
}
57+
58+
for i := 0; i < 2; i++ {
59+
runTests()
60+
}
61+
ignoreCurrent := goleak.IgnoreCurrent()
62+
63+
for i := 0; i < 10; i++ {
64+
runTests()
65+
}
66+
goleak.VerifyNone(t, ignoreCurrent,
67+
goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).worker"),
68+
goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).multiplexSessionWorker"),
69+
goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*healthChecker).maintainer"),
70+
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*controlBuffer).get"),
71+
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/transport.(*http2Server).keepalive"),
72+
goleak.IgnoreTopFunction("internal/poll.runtime_pollWait"),
73+
goleak.IgnoreTopFunction("google.golang.org/grpc/internal/grpcsync.(*CallbackSerializer).run"),
74+
goleak.IgnoreTopFunction("cloud.google.com/go/spanner.(*sessionPool).createMultiplexedSession"),
75+
)
5576
}
5677

5778
func pingContext(ctx context.Context, t *testing.T, db *sql.DB) {
@@ -308,6 +329,50 @@ func runTransactionRetry(ctx context.Context, t *testing.T, server *testutil.Moc
308329
}
309330
}
310331

332+
func runTransactionRetryAbortedHalfway(ctx context.Context, t *testing.T, server *testutil.MockedSpannerInMemTestServer, db *sql.DB) {
333+
var attempts int
334+
err := RunTransactionWithOptions(ctx, db, &sql.TxOptions{}, func(ctx context.Context, tx *sql.Tx) error {
335+
attempts++
336+
rows, err := tx.QueryContext(ctx, testutil.SelectFooFromBar, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_1"}})
337+
if err != nil {
338+
t.Fatal(err)
339+
}
340+
for rows.Next() {
341+
}
342+
if err := rows.Close(); err != nil {
343+
t.Fatal(err)
344+
}
345+
346+
if attempts == 1 {
347+
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{
348+
Errors: []error{gstatus.Error(codes.Aborted, "Aborted")},
349+
})
350+
}
351+
if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo, ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_2"}}); err != nil {
352+
return err
353+
}
354+
355+
if attempts == 2 {
356+
server.TestSpanner.PutExecutionTime(testutil.MethodExecuteBatchDml, testutil.SimulatedExecutionTime{
357+
Errors: []error{gstatus.Error(codes.Aborted, "Aborted")},
358+
})
359+
}
360+
if _, err := tx.ExecContext(ctx, "start batch dml", ExecOptions{QueryOptions: spanner.QueryOptions{RequestTag: "tag_3"}}); err != nil {
361+
return err
362+
}
363+
if _, err := tx.ExecContext(ctx, testutil.UpdateBarSetFoo); err != nil {
364+
return err
365+
}
366+
if _, err := tx.ExecContext(ctx, "run batch"); err != nil {
367+
return err
368+
}
369+
return nil
370+
}, spanner.TransactionOptions{TransactionTag: "my_transaction_tag"})
371+
if err != nil {
372+
t.Fatalf("failed to run transaction: %v", err)
373+
}
374+
}
375+
311376
func readOnlyTxWithOptions(ctx context.Context, t *testing.T, db *sql.DB) {
312377
tx, err := BeginReadOnlyTransaction(ctx, db,
313378
ReadOnlyTransactionOptions{TimestampBound: spanner.ExactStaleness(10 * time.Second)})

connection_properties.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,16 @@ var propertyTransactionBatchReadOnly = createConnectionProperty(
305305
connectionstate.ContextUser,
306306
connectionstate.ConvertBool,
307307
)
308+
var propertyTransactionTimeout = createConnectionProperty(
309+
"transaction_timeout",
310+
"The timeout to apply to all read/write transactions on this connection. "+
311+
"Setting the timeout to zero means no timeout.",
312+
time.Duration(0),
313+
false,
314+
nil,
315+
connectionstate.ContextUser,
316+
connectionstate.ConvertDuration,
317+
)
308318

309319
// ------------------------------------------------------------------------------------------------
310320
// Statement connection properties.
@@ -318,6 +328,17 @@ var propertyStatementTag = createConnectionProperty(
318328
connectionstate.ContextUser,
319329
connectionstate.ConvertString,
320330
)
331+
var propertyStatementTimeout = createConnectionProperty(
332+
"statement_timeout",
333+
"The timeout to apply to all statements on this connection. "+
334+
"Setting the timeout to zero means no timeout. "+
335+
"Any existing context deadline will take precedence over this statement timeout.",
336+
time.Duration(0),
337+
false,
338+
nil,
339+
connectionstate.ContextUser,
340+
connectionstate.ConvertDuration,
341+
)
321342

322343
// ------------------------------------------------------------------------------------------------
323344
// Startup connection properties.

driver.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1022,6 +1022,9 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti
10221022
return nil, err
10231023
}
10241024
for {
1025+
// Create a derived context for each new attempt
1026+
// ctx, cancel := context.WithCancel(ctx)
1027+
10251028
err = protected(ctx, tx, f)
10261029
errDuringCommit := false
10271030
if err == nil {

driver_with_mockserver_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4572,9 +4572,9 @@ func TestRunTransaction(t *testing.T) {
45724572
// Verify that internal retries are disabled during RunTransaction
45734573
txi := reflect.ValueOf(tx).Elem().FieldByName("txi")
45744574
delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer())
4575-
rwTx := delegatingTx.contextTransaction.(*readWriteTransaction)
4575+
rwTx, ok := delegatingTx.contextTransaction.(*readWriteTransaction)
45764576
// Verify that getting the transaction through reflection worked.
4577-
if g, w := rwTx.ctx, ctx; g != w {
4577+
if !ok {
45784578
return fmt.Errorf("getting the transaction through reflection failed")
45794579
}
45804580
if rwTx.retryAborts() {
@@ -5034,9 +5034,9 @@ func TestBeginReadWriteTransaction(t *testing.T) {
50345034
// Verify that internal retries are disabled during this transaction.
50355035
txi := reflect.ValueOf(tx).Elem().FieldByName("txi")
50365036
delegatingTx := (*delegatingTransaction)(txi.Elem().UnsafePointer())
5037-
rwTx := delegatingTx.contextTransaction.(*readWriteTransaction)
5037+
rwTx, ok := delegatingTx.contextTransaction.(*readWriteTransaction)
50385038
// Verify that getting the transaction through reflection worked.
5039-
if g, w := rwTx.ctx, ctx; g != w {
5039+
if !ok {
50405040
t.Fatal("getting the transaction through reflection failed")
50415041
}
50425042
if rwTx.retryAborts() {

0 commit comments

Comments
 (0)