Skip to content

Commit b0c75c5

Browse files
olavloiteaakashanandg
authored andcommitted
feat: support statement_timeout and transaction_timeout property (googleapis#578)
* feat: support statement_timeout and transaction_timeout property Add a statement_timeout connection property that is used as the default timeout for the execution of all statements that are executed on a connection. The timeout is only used for the actual execution, and not attached to the iterator that is returned for a query. This also means that a query that is executed without the DirectExecuteQuery option, will ignore the statement_timeout value. Also adds a transaction_timeout property that is additionally used for all statements in a read/write transaction. The deadline of the transaction is calculated at the start of the transaction, and all statements in the transaction get this deadline, unless the statement already has an earlier deadline from for example a statement_timeout or a context deadline. This change also fixes some issues with deadlines when using the gRPC API of SpannerLib. The context that is used for an RPC invocation is cancelled after the RPC has finished. This context should therefore not be used as the context for any query execution, as the context is attached to the row iterator, and would cancel the query execution halfway. Fixes googleapis#574 Fixes googleapis#575 * chore: use errors.Join instead of two %w verbs
1 parent 5ce2f0a commit b0c75c5

17 files changed

+698
-67
lines changed

conn.go

Lines changed: 179 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831831
return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil
832832
}
833833

834+
// Adds any statement or transaction timeout to the given context. The deadline of the returned
835+
// context will be the earliest of:
836+
// 1. Any existing deadline on the input context.
837+
// 2. Any existing transaction deadline.
838+
// 3. A deadline calculated from the current time + the value of statement_timeout.
839+
func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) {
840+
var statementDeadline time.Time
841+
var transactionDeadline time.Time
842+
var deadline time.Time
843+
var hasStatementDeadline bool
844+
var hasTransactionDeadline bool
845+
846+
// Check if the connection has a value for statement_timeout.
847+
statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state)
848+
if statementTimeout != time.Duration(0) {
849+
hasStatementDeadline = true
850+
statementDeadline = time.Now().Add(statementTimeout)
851+
}
852+
// Check if the current transaction has a deadline.
853+
transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline()
854+
if err != nil {
855+
return nil, nil, err
856+
}
857+
858+
// If there is no statement_timeout and no current transaction deadline,
859+
// then can just use the input context as-is.
860+
if !hasStatementDeadline && !hasTransactionDeadline {
861+
return ctx, func() {}, nil
862+
}
863+
864+
// If there is both a transaction and a statement deadline, then we use the earliest
865+
// of those two.
866+
if hasTransactionDeadline && hasStatementDeadline {
867+
if statementDeadline.Before(transactionDeadline) {
868+
deadline = statementDeadline
869+
} else {
870+
deadline = transactionDeadline
871+
}
872+
} else if hasStatementDeadline {
873+
deadline = statementDeadline
874+
} else {
875+
deadline = transactionDeadline
876+
}
877+
// context.WithDeadline automatically selects the earliest deadline of
878+
// the existing deadline on the context and the given deadline.
879+
newCtx, cancel := context.WithDeadline(ctx, deadline)
880+
return newCtx, cancel, nil
881+
}
882+
883+
// transactionDeadline returns the deadline of the current transaction
884+
// on the connection. This also activates the transaction if it is not
885+
// yet activated.
886+
func (c *conn) transactionDeadline() (time.Time, bool, error) {
887+
if c.tx == nil {
888+
return time.Time{}, false, nil
889+
}
890+
if err := c.tx.ensureActivated(); err != nil {
891+
return time.Time{}, false, err
892+
}
893+
deadline, hasDeadline := c.tx.deadline()
894+
return deadline, hasDeadline, nil
895+
}
896+
834897
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
835898
// Execute client side statement if it is one.
836899
clientStmt, err := c.parser.ParseClientSideStatement(query)
@@ -849,13 +912,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849912
return c.queryContext(ctx, query, execOptions, args)
850913
}
851914

852-
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
915+
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
916+
ctx, cancelCause := context.WithCancelCause(ctx)
917+
cancel := func() {
918+
cancelCause(nil)
919+
}
920+
defer func() {
921+
if returnedErr != nil {
922+
cancel()
923+
}
924+
}()
853925
// Clear the commit timestamp of this connection before we execute the query.
854926
c.clearCommitResponse()
855927
// Check if the execution options contains an instruction to execute
856928
// a specific partition of a PartitionedQuery.
857929
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
858-
return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
930+
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
859931
}
860932

861933
stmt, err := prepareSpannerStmt(c.parser, query, args)
@@ -869,7 +941,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869941
if err != nil {
870942
return nil, err
871943
}
872-
return createDriverResultRows(res, execOptions), nil
944+
return createDriverResultRows(res, cancel, execOptions), nil
873945
}
874946
var iter rowIterator
875947
if c.tx == nil {
@@ -884,7 +956,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884956
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
885957
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
886958
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
887-
return c.executeAutoPartitionedQuery(ctx, query, execOptions, args)
959+
return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args)
888960
} else {
889961
// The statement was either detected as being a query, or potentially not recognized at all.
890962
// In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +965,75 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893965
}
894966
} else {
895967
if execOptions.PartitionedQueryOptions.PartitionQuery {
968+
// The driver.Rows instance that is returned for partitionQuery does not
969+
// contain a context, and therefore also does not cancel the context when it is closed.
970+
defer cancel()
896971
return c.tx.partitionQuery(ctx, stmt, execOptions)
897972
}
898973
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
899974
if err != nil {
900975
return nil, err
901976
}
902977
}
903-
res := createRows(iter, execOptions)
978+
res := createRows(iter, cancel, execOptions)
904979
if execOptions.DirectExecuteQuery {
905-
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906-
res.getColumns()
907-
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
908-
_ = res.Close()
909-
return nil, res.dirtyErr
980+
if err := c.directExecuteQuery(ctx, cancelCause, res, execOptions); err != nil {
981+
return nil, err
910982
}
911983
}
912984
return res, nil
913985
}
914986

987+
// directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or
988+
// transaction_timeout is used while waiting for the first result to be returned.
989+
func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error {
990+
statementCtx := ctx
991+
if execOptions.DirectExecuteContext != nil {
992+
statementCtx = execOptions.DirectExecuteContext
993+
}
994+
// Add the statement or transaction deadline to the context.
995+
statementCtx, cancelStatement, err := c.addStatementAndTransactionTimeout(statementCtx)
996+
if err != nil {
997+
return err
998+
}
999+
defer cancelStatement()
1000+
1001+
// Asynchronously fetch the first partial result set from Spanner.
1002+
done := make(chan struct{})
1003+
go func() {
1004+
// Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the
1005+
// metadata of the query.
1006+
defer close(done)
1007+
res.getColumns()
1008+
}()
1009+
// Wait until either the done channel is closed or the context is done.
1010+
var statementErr error
1011+
select {
1012+
case <-statementCtx.Done():
1013+
statementErr = statementCtx.Err()
1014+
// Cancel the query execution.
1015+
cancelQuery(statementCtx.Err())
1016+
case <-done:
1017+
}
1018+
1019+
// Now wait until done channel is closed. This could be because the execution finished
1020+
// successfully, or because the context was cancelled, which again causes the execution
1021+
// to (eventually) fail.
1022+
<-done
1023+
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
1024+
_ = res.Close()
1025+
if statementErr != nil {
1026+
// Create a status error from the statement error and wrap both the Spanner error and the status error into
1027+
// one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1028+
// ID from the Spanner error.
1029+
s := status.FromContextError(statementErr)
1030+
return errors.Join(s.Err(), res.dirtyErr)
1031+
}
1032+
return res.dirtyErr
1033+
}
1034+
return nil
1035+
}
1036+
9151037
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
9161038
// Execute client side statement if it is one.
9171039
stmt, err := c.parser.ParseClientSideStatement(query)
@@ -929,7 +1051,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291051
return c.execContext(ctx, query, execOptions, args)
9301052
}
9311053

932-
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) {
1054+
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) {
1055+
// Add the statement/transaction deadline to the context.
1056+
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
1057+
if err != nil {
1058+
return nil, err
1059+
}
1060+
defer cancel()
9331061
// Clear the commit timestamp of this connection before we execute the statement.
9341062
c.clearCommitResponse()
9351063

@@ -1041,6 +1169,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411169
return noTransaction()
10421170
}
10431171
c.tx = c.prevTx
1172+
// If the aborted error happened during the Commit, then the transaction
1173+
// context has been cancelled, and we need to create a new one.
1174+
if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok {
1175+
newCtx, cancel := c.addTransactionTimeout(c.tx.ctx)
1176+
rwTx.ctx = newCtx
1177+
// Make sure that we cancel the new context when the transaction is closed.
1178+
origClose := rwTx.close
1179+
rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1180+
origClose(result, commitResponse, commitErr)
1181+
cancel()
1182+
}
1183+
}
10441184
c.resetForRetry = true
10451185
} else if c.tx == nil {
10461186
return noTransaction()
@@ -1248,6 +1388,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481388
return c.tx, nil
12491389
}
12501390

1391+
// addTransactionTimeout creates a new derived context with the current transaction_timeout.
1392+
func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
1393+
timeout := propertyTransactionTimeout.GetValueOrDefault(c.state)
1394+
if timeout == time.Duration(0) {
1395+
return ctx, func() {}
1396+
}
1397+
// Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated
1398+
// deadline based on the timeout.
1399+
return context.WithTimeout(ctx, timeout)
1400+
}
1401+
12511402
func (c *conn) activateTransaction() (contextTransaction, error) {
12521403
closeFunc := c.tx.close
12531404
if propertyTransactionReadOnly.GetValueOrDefault(c.state) {
@@ -1283,19 +1434,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831434
opts := spanner.TransactionOptions{}
12841435
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))
12851436

1286-
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions {
1437+
// Add the current value of transaction_timeout to the context that is registered
1438+
// on the transaction.
1439+
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
1440+
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
12871441
defer func() {
12881442
// Reset the transaction_tag after starting the transaction.
12891443
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
12901444
}()
12911445
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
12921446
})
12931447
if err != nil {
1448+
cancel()
12941449
return nil, err
12951450
}
12961451
logger := c.logger.With("tx", "rw")
12971452
return &readWriteTransaction{
1298-
ctx: c.tx.ctx,
1453+
ctx: ctx,
12991454
conn: c,
13001455
logger: logger,
13011456
rwTx: tx,
@@ -1307,6 +1462,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071462
} else {
13081463
closeFunc(txResultRollback)
13091464
}
1465+
cancel()
13101466
},
13111467
retryAborts: sync.OnceValue(func() bool {
13121468
return c.RetryAbortsInternally()
@@ -1371,7 +1527,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711527
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
13721528
}
13731529

1374-
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
1530+
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
1531+
// The cancel() function is called by the returned Rows object when it is closed.
1532+
// However, if an error is returned instead of a Rows instance, we need to cancel
1533+
// the context when we return from this function.
1534+
defer func() {
1535+
if returnedErr != nil {
1536+
cancel()
1537+
}
1538+
}()
13751539
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))})
13761540
if err != nil {
13771541
return nil, err
@@ -1383,6 +1547,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831547
}
13841548
if rows, ok := r.(*rows); ok {
13851549
rows.close = func() error {
1550+
defer cancel()
13861551
return tx.Commit()
13871552
}
13881553
}

0 commit comments

Comments
 (0)