Skip to content

Commit 1791559

Browse files
committed
feat: support statement_timeout and transaction_timeout property
Add a statement_timeout connection property that is used as the default timeout for the execution of all statements that are executed on a connection. The timeout is only used for the actual execution, and not attached to the iterator that is returned for a query. This also means that a query that is executed without the DirectExecuteQuery option, will ignore the statement_timeout value. Also adds a transaction_timeout property that is additionally used for all statements in a read/write transaction. The deadline of the transaction is calculated at the start of the transaction, and all statements in the transaction get this deadline, unless the statement already has an earlier deadline from for example a statement_timeout or a context deadline. This change also fixes some issues with deadlines when using the gRPC API of SpannerLib. The context that is used for an RPC invocation is cancelled after the RPC has finished. This context should therefore not be used as the context for any query execution, as the context is attached to the row iterator, and would cancel the query execution halfway. Fixes #574 Fixes #575
1 parent ad05fde commit 1791559

17 files changed

+692
-67
lines changed

conn.go

Lines changed: 174 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"database/sql"
2020
"database/sql/driver"
2121
"errors"
22+
"fmt"
2223
"log/slog"
2324
"slices"
2425
"sync"
@@ -831,6 +832,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831832
return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil
832833
}
833834

835+
// Adds any statement or transaction timeout to the given context. The deadline of the returned
836+
// context will be the earliest of:
837+
// 1. Any existing deadline on the input context.
838+
// 2. Any existing transaction deadline.
839+
// 3. A deadline calculated from the current time + the value of statement_timeout.
840+
func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) {
841+
var statementDeadline time.Time
842+
var transactionDeadline time.Time
843+
var deadline time.Time
844+
var hasStatementDeadline bool
845+
var hasTransactionDeadline bool
846+
847+
// Check if the connection has a value for statement_timeout.
848+
statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state)
849+
if statementTimeout != time.Duration(0) {
850+
hasStatementDeadline = true
851+
statementDeadline = time.Now().Add(statementTimeout)
852+
}
853+
// Check if the current transaction has a deadline.
854+
transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline()
855+
if err != nil {
856+
return nil, nil, err
857+
}
858+
859+
// If there is no statement_timeout and no current transaction deadline,
860+
// then can just use the input context as-is.
861+
if !hasStatementDeadline && !hasTransactionDeadline {
862+
return ctx, func() {}, nil
863+
}
864+
865+
// If there is both a transaction and a statement deadline, then we use the earliest
866+
// of those two.
867+
if hasTransactionDeadline && hasStatementDeadline {
868+
if statementDeadline.Before(transactionDeadline) {
869+
deadline = statementDeadline
870+
} else {
871+
deadline = transactionDeadline
872+
}
873+
} else if hasStatementDeadline {
874+
deadline = statementDeadline
875+
} else {
876+
deadline = transactionDeadline
877+
}
878+
// context.WithDeadline automatically selects the earliest deadline of
879+
// the existing deadline on the context and the given deadline.
880+
newCtx, cancel := context.WithDeadline(ctx, deadline)
881+
return newCtx, cancel, nil
882+
}
883+
884+
// transactionDeadline returns the deadline of the current transaction
885+
// on the connection. This also activates the transaction if it is not
886+
// yet activated.
887+
func (c *conn) transactionDeadline() (time.Time, bool, error) {
888+
if c.tx == nil {
889+
return time.Time{}, false, nil
890+
}
891+
if err := c.tx.ensureActivated(); err != nil {
892+
return time.Time{}, false, err
893+
}
894+
deadline, hasDeadline := c.tx.deadline()
895+
return deadline, hasDeadline, nil
896+
}
897+
834898
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
835899
// Execute client side statement if it is one.
836900
clientStmt, err := c.parser.ParseClientSideStatement(query)
@@ -849,13 +913,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849913
return c.queryContext(ctx, query, execOptions, args)
850914
}
851915

852-
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
916+
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
917+
ctx, cancelCause := context.WithCancelCause(ctx)
918+
cancel := func() {
919+
cancelCause(nil)
920+
}
921+
defer func() {
922+
if returnedErr != nil {
923+
cancel()
924+
}
925+
}()
853926
// Clear the commit timestamp of this connection before we execute the query.
854927
c.clearCommitResponse()
855928
// Check if the execution options contains an instruction to execute
856929
// a specific partition of a PartitionedQuery.
857930
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
858-
return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
931+
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
859932
}
860933

861934
stmt, err := prepareSpannerStmt(c.parser, query, args)
@@ -869,7 +942,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869942
if err != nil {
870943
return nil, err
871944
}
872-
return createDriverResultRows(res, execOptions), nil
945+
return createDriverResultRows(res, cancel, execOptions), nil
873946
}
874947
var iter rowIterator
875948
if c.tx == nil {
@@ -884,7 +957,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884957
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
885958
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
886959
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
887-
return c.executeAutoPartitionedQuery(ctx, query, execOptions, args)
960+
return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args)
888961
} else {
889962
// The statement was either detected as being a query, or potentially not recognized at all.
890963
// In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +966,71 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893966
}
894967
} else {
895968
if execOptions.PartitionedQueryOptions.PartitionQuery {
969+
// The driver.Rows instance that is returned for partitionQuery does not
970+
// contain a context, and therefore also does not cancel the context when it is closed.
971+
defer cancel()
896972
return c.tx.partitionQuery(ctx, stmt, execOptions)
897973
}
898974
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
899975
if err != nil {
900976
return nil, err
901977
}
902978
}
903-
res := createRows(iter, execOptions)
979+
res := createRows(iter, cancel, execOptions)
904980
if execOptions.DirectExecuteQuery {
905-
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906-
res.getColumns()
907-
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
908-
_ = res.Close()
909-
return nil, res.dirtyErr
981+
if err := c.directExecuteQuery(ctx, cancelCause, res, execOptions); err != nil {
982+
return nil, err
910983
}
911984
}
912985
return res, nil
913986
}
914987

988+
func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error {
989+
statementCtx := ctx
990+
if execOptions.DirectExecuteQuery && execOptions.DirectExecuteContext != nil {
991+
statementCtx = execOptions.DirectExecuteContext
992+
}
993+
// Add the statement or transaction deadline to the context.
994+
statementCtx, cancelStatement, err := c.addStatementAndTransactionTimeout(statementCtx)
995+
if err != nil {
996+
return err
997+
}
998+
defer cancelStatement()
999+
1000+
// Asynchronously fetch the first partial result set from Spanner.
1001+
done := make(chan struct{})
1002+
go func() {
1003+
defer close(done)
1004+
res.getColumns()
1005+
}()
1006+
// Wait until either the done channel is closed or the context is done.
1007+
var statementErr error
1008+
select {
1009+
case <-statementCtx.Done():
1010+
statementErr = statementCtx.Err()
1011+
// Cancel the execution.
1012+
cancelQuery(statementCtx.Err())
1013+
case <-done:
1014+
}
1015+
1016+
// Now wait until done channel is closed. This could be because the execution finished
1017+
// successfully, or because the context was cancelled, which again causes the execution
1018+
// to (eventually) fail.
1019+
<-done
1020+
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
1021+
_ = res.Close()
1022+
if statementErr != nil {
1023+
// Create a status error from the statement error and wrap both the Spanner error and the status error into
1024+
// one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1025+
// ID from the Spanner error.
1026+
s := status.FromContextError(statementErr)
1027+
return fmt.Errorf("%w: %w", s.Err(), res.dirtyErr)
1028+
}
1029+
return res.dirtyErr
1030+
}
1031+
return nil
1032+
}
1033+
9151034
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
9161035
// Execute client side statement if it is one.
9171036
stmt, err := c.parser.ParseClientSideStatement(query)
@@ -929,7 +1048,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291048
return c.execContext(ctx, query, execOptions, args)
9301049
}
9311050

932-
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) {
1051+
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) {
1052+
// Add the statement/transaction deadline to the context.
1053+
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
1054+
if err != nil {
1055+
return nil, err
1056+
}
1057+
defer cancel()
9331058
// Clear the commit timestamp of this connection before we execute the statement.
9341059
c.clearCommitResponse()
9351060

@@ -1041,6 +1166,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411166
return noTransaction()
10421167
}
10431168
c.tx = c.prevTx
1169+
// If the aborted error happened during the Commit, then the transaction
1170+
// context has been cancelled, and we need to create a new one.
1171+
if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok {
1172+
newCtx, cancel := c.addTransactionTimeout(c.tx.ctx)
1173+
rwTx.ctx = newCtx
1174+
// Make sure that we cancel the new context when the transaction is closed.
1175+
origClose := rwTx.close
1176+
rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1177+
origClose(result, commitResponse, commitErr)
1178+
cancel()
1179+
}
1180+
}
10441181
c.resetForRetry = true
10451182
} else if c.tx == nil {
10461183
return noTransaction()
@@ -1248,6 +1385,15 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481385
return c.tx, nil
12491386
}
12501387

1388+
// addTransactionTimeout adds the current transaction_timeout to the given context.
1389+
func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
1390+
timeout := propertyTransactionTimeout.GetValueOrDefault(c.state)
1391+
if timeout == time.Duration(0) {
1392+
return ctx, func() {}
1393+
}
1394+
return context.WithTimeout(ctx, timeout)
1395+
}
1396+
12511397
func (c *conn) activateTransaction() (contextTransaction, error) {
12521398
closeFunc := c.tx.close
12531399
if propertyTransactionReadOnly.GetValueOrDefault(c.state) {
@@ -1283,19 +1429,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831429
opts := spanner.TransactionOptions{}
12841430
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))
12851431

1286-
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions {
1432+
// Add the current value of transaction_timeout to the context that is registered
1433+
// on the transaction.
1434+
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
1435+
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
12871436
defer func() {
12881437
// Reset the transaction_tag after starting the transaction.
12891438
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
12901439
}()
12911440
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
12921441
})
12931442
if err != nil {
1443+
cancel()
12941444
return nil, err
12951445
}
12961446
logger := c.logger.With("tx", "rw")
12971447
return &readWriteTransaction{
1298-
ctx: c.tx.ctx,
1448+
ctx: ctx,
12991449
conn: c,
13001450
logger: logger,
13011451
rwTx: tx,
@@ -1307,6 +1457,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071457
} else {
13081458
closeFunc(txResultRollback)
13091459
}
1460+
cancel()
13101461
},
13111462
retryAborts: sync.OnceValue(func() bool {
13121463
return c.RetryAbortsInternally()
@@ -1371,7 +1522,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711522
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
13721523
}
13731524

1374-
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
1525+
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
1526+
// The cancel() function is called by the returned Rows object when it is closed.
1527+
// However, if an error is returned instead of a Rows instance, we need to cancel
1528+
// the context when we return from this function.
1529+
defer func() {
1530+
if returnedErr != nil {
1531+
cancel()
1532+
}
1533+
}()
13751534
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))})
13761535
if err != nil {
13771536
return nil, err
@@ -1383,6 +1542,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831542
}
13841543
if rows, ok := r.(*rows); ok {
13851544
rows.close = func() error {
1545+
defer cancel()
13861546
return tx.Commit()
13871547
}
13881548
}

0 commit comments

Comments
 (0)