@@ -19,6 +19,7 @@ import (
1919 "database/sql"
2020 "database/sql/driver"
2121 "errors"
22+ "fmt"
2223 "log/slog"
2324 "slices"
2425 "sync"
@@ -831,6 +832,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831832 return & stmt {conn : c , query : parsedSQL , numArgs : len (args ), execOptions : execOptions }, nil
832833}
833834
835+ // Adds any statement or transaction timeout to the given context. The deadline of the returned
836+ // context will be the earliest of:
837+ // 1. Any existing deadline on the input context.
838+ // 2. Any existing transaction deadline.
839+ // 3. A deadline calculated from the current time + the value of statement_timeout.
840+ func (c * conn ) addStatementAndTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc , error ) {
841+ var statementDeadline time.Time
842+ var transactionDeadline time.Time
843+ var deadline time.Time
844+ var hasStatementDeadline bool
845+ var hasTransactionDeadline bool
846+
847+ // Check if the connection has a value for statement_timeout.
848+ statementTimeout := propertyStatementTimeout .GetValueOrDefault (c .state )
849+ if statementTimeout != time .Duration (0 ) {
850+ hasStatementDeadline = true
851+ statementDeadline = time .Now ().Add (statementTimeout )
852+ }
853+ // Check if the current transaction has a deadline.
854+ transactionDeadline , hasTransactionDeadline , err := c .transactionDeadline ()
855+ if err != nil {
856+ return nil , nil , err
857+ }
858+
859+ // If there is no statement_timeout and no current transaction deadline,
860+ // then can just use the input context as-is.
861+ if ! hasStatementDeadline && ! hasTransactionDeadline {
862+ return ctx , func () {}, nil
863+ }
864+
865+ // If there is both a transaction and a statement deadline, then we use the earliest
866+ // of those two.
867+ if hasTransactionDeadline && hasStatementDeadline {
868+ if statementDeadline .Before (transactionDeadline ) {
869+ deadline = statementDeadline
870+ } else {
871+ deadline = transactionDeadline
872+ }
873+ } else if hasStatementDeadline {
874+ deadline = statementDeadline
875+ } else {
876+ deadline = transactionDeadline
877+ }
878+ // context.WithDeadline automatically selects the earliest deadline of
879+ // the existing deadline on the context and the given deadline.
880+ newCtx , cancel := context .WithDeadline (ctx , deadline )
881+ return newCtx , cancel , nil
882+ }
883+
884+ // transactionDeadline returns the deadline of the current transaction
885+ // on the connection. This also activates the transaction if it is not
886+ // yet activated.
887+ func (c * conn ) transactionDeadline () (time.Time , bool , error ) {
888+ if c .tx == nil {
889+ return time.Time {}, false , nil
890+ }
891+ if err := c .tx .ensureActivated (); err != nil {
892+ return time.Time {}, false , err
893+ }
894+ deadline , hasDeadline := c .tx .deadline ()
895+ return deadline , hasDeadline , nil
896+ }
897+
834898func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
835899 // Execute client side statement if it is one.
836900 clientStmt , err := c .parser .ParseClientSideStatement (query )
@@ -849,13 +913,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849913 return c .queryContext (ctx , query , execOptions , args )
850914}
851915
852- func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
916+ func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
917+ ctx , cancelCause := context .WithCancelCause (ctx )
918+ cancel := func () {
919+ cancelCause (nil )
920+ }
921+ defer func () {
922+ if returnedErr != nil {
923+ cancel ()
924+ }
925+ }()
853926 // Clear the commit timestamp of this connection before we execute the query.
854927 c .clearCommitResponse ()
855928 // Check if the execution options contains an instruction to execute
856929 // a specific partition of a PartitionedQuery.
857930 if pq := execOptions .PartitionedQueryOptions .ExecutePartition .PartitionedQuery ; pq != nil {
858- return pq .execute (ctx , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
931+ return pq .execute (ctx , cancel , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
859932 }
860933
861934 stmt , err := prepareSpannerStmt (c .parser , query , args )
@@ -869,7 +942,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869942 if err != nil {
870943 return nil , err
871944 }
872- return createDriverResultRows (res , execOptions ), nil
945+ return createDriverResultRows (res , cancel , execOptions ), nil
873946 }
874947 var iter rowIterator
875948 if c .tx == nil {
@@ -884,7 +957,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884957 } else if execOptions .PartitionedQueryOptions .PartitionQuery {
885958 return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "PartitionQuery is only supported in batch read-only transactions" ))
886959 } else if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
887- return c .executeAutoPartitionedQuery (ctx , query , execOptions , args )
960+ return c .executeAutoPartitionedQuery (ctx , cancel , query , execOptions , args )
888961 } else {
889962 // The statement was either detected as being a query, or potentially not recognized at all.
890963 // In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +966,71 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893966 }
894967 } else {
895968 if execOptions .PartitionedQueryOptions .PartitionQuery {
969+ // The driver.Rows instance that is returned for partitionQuery does not
970+ // contain a context, and therefore also does not cancel the context when it is closed.
971+ defer cancel ()
896972 return c .tx .partitionQuery (ctx , stmt , execOptions )
897973 }
898974 iter , err = c .tx .Query (ctx , stmt , statementInfo .StatementType , execOptions )
899975 if err != nil {
900976 return nil , err
901977 }
902978 }
903- res := createRows (iter , execOptions )
979+ res := createRows (iter , cancel , execOptions )
904980 if execOptions .DirectExecuteQuery {
905- // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906- res .getColumns ()
907- if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
908- _ = res .Close ()
909- return nil , res .dirtyErr
981+ if err := c .directExecuteQuery (ctx , cancelCause , res , execOptions ); err != nil {
982+ return nil , err
910983 }
911984 }
912985 return res , nil
913986}
914987
988+ func (c * conn ) directExecuteQuery (ctx context.Context , cancelQuery context.CancelCauseFunc , res * rows , execOptions * ExecOptions ) error {
989+ statementCtx := ctx
990+ if execOptions .DirectExecuteQuery && execOptions .DirectExecuteContext != nil {
991+ statementCtx = execOptions .DirectExecuteContext
992+ }
993+ // Add the statement or transaction deadline to the context.
994+ statementCtx , cancelStatement , err := c .addStatementAndTransactionTimeout (statementCtx )
995+ if err != nil {
996+ return err
997+ }
998+ defer cancelStatement ()
999+
1000+ // Asynchronously fetch the first partial result set from Spanner.
1001+ done := make (chan struct {})
1002+ go func () {
1003+ defer close (done )
1004+ res .getColumns ()
1005+ }()
1006+ // Wait until either the done channel is closed or the context is done.
1007+ var statementErr error
1008+ select {
1009+ case <- statementCtx .Done ():
1010+ statementErr = statementCtx .Err ()
1011+ // Cancel the execution.
1012+ cancelQuery (statementCtx .Err ())
1013+ case <- done :
1014+ }
1015+
1016+ // Now wait until done channel is closed. This could be because the execution finished
1017+ // successfully, or because the context was cancelled, which again causes the execution
1018+ // to (eventually) fail.
1019+ <- done
1020+ if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
1021+ _ = res .Close ()
1022+ if statementErr != nil {
1023+ // Create a status error from the statement error and wrap both the Spanner error and the status error into
1024+ // one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1025+ // ID from the Spanner error.
1026+ s := status .FromContextError (statementErr )
1027+ return fmt .Errorf ("%w: %w" , s .Err (), res .dirtyErr )
1028+ }
1029+ return res .dirtyErr
1030+ }
1031+ return nil
1032+ }
1033+
9151034func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
9161035 // Execute client side statement if it is one.
9171036 stmt , err := c .parser .ParseClientSideStatement (query )
@@ -929,7 +1048,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291048 return c .execContext (ctx , query , execOptions , args )
9301049}
9311050
932- func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Result , error ) {
1051+ func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedResult driver.Result , returnedErr error ) {
1052+ // Add the statement/transaction deadline to the context.
1053+ ctx , cancel , err := c .addStatementAndTransactionTimeout (ctx )
1054+ if err != nil {
1055+ return nil , err
1056+ }
1057+ defer cancel ()
9331058 // Clear the commit timestamp of this connection before we execute the statement.
9341059 c .clearCommitResponse ()
9351060
@@ -1041,6 +1166,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411166 return noTransaction ()
10421167 }
10431168 c .tx = c .prevTx
1169+ // If the aborted error happened during the Commit, then the transaction
1170+ // context has been cancelled, and we need to create a new one.
1171+ if rwTx , ok := c .tx .contextTransaction .(* readWriteTransaction ); ok {
1172+ newCtx , cancel := c .addTransactionTimeout (c .tx .ctx )
1173+ rwTx .ctx = newCtx
1174+ // Make sure that we cancel the new context when the transaction is closed.
1175+ origClose := rwTx .close
1176+ rwTx .close = func (result txResult , commitResponse * spanner.CommitResponse , commitErr error ) {
1177+ origClose (result , commitResponse , commitErr )
1178+ cancel ()
1179+ }
1180+ }
10441181 c .resetForRetry = true
10451182 } else if c .tx == nil {
10461183 return noTransaction ()
@@ -1248,6 +1385,15 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481385 return c .tx , nil
12491386}
12501387
1388+ // addTransactionTimeout adds the current transaction_timeout to the given context.
1389+ func (c * conn ) addTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc ) {
1390+ timeout := propertyTransactionTimeout .GetValueOrDefault (c .state )
1391+ if timeout == time .Duration (0 ) {
1392+ return ctx , func () {}
1393+ }
1394+ return context .WithTimeout (ctx , timeout )
1395+ }
1396+
12511397func (c * conn ) activateTransaction () (contextTransaction , error ) {
12521398 closeFunc := c .tx .close
12531399 if propertyTransactionReadOnly .GetValueOrDefault (c .state ) {
@@ -1283,19 +1429,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831429 opts := spanner.TransactionOptions {}
12841430 opts .BeginTransactionOption = c .convertDefaultBeginTransactionOption (propertyBeginTransactionOption .GetValueOrDefault (c .state ))
12851431
1286- tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (c .tx .ctx , c .client , opts , func () spanner.TransactionOptions {
1432+ // Add the current value of transaction_timeout to the context that is registered
1433+ // on the transaction.
1434+ ctx , cancel := c .addTransactionTimeout (c .tx .ctx )
1435+ tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (ctx , c .client , opts , func () spanner.TransactionOptions {
12871436 defer func () {
12881437 // Reset the transaction_tag after starting the transaction.
12891438 _ = propertyTransactionTag .ResetValue (c .state , connectionstate .ContextUser )
12901439 }()
12911440 return c .effectiveTransactionOptions (spannerpb .TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED , c .options ( /*reset=*/ true ))
12921441 })
12931442 if err != nil {
1443+ cancel ()
12941444 return nil , err
12951445 }
12961446 logger := c .logger .With ("tx" , "rw" )
12971447 return & readWriteTransaction {
1298- ctx : c . tx . ctx ,
1448+ ctx : ctx ,
12991449 conn : c ,
13001450 logger : logger ,
13011451 rwTx : tx ,
@@ -1307,6 +1457,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071457 } else {
13081458 closeFunc (txResultRollback )
13091459 }
1460+ cancel ()
13101461 },
13111462 retryAborts : sync .OnceValue (func () bool {
13121463 return c .RetryAbortsInternally ()
@@ -1371,7 +1522,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711522 return c .Single ().WithTimestampBound (tb ).QueryWithOptions (ctx , statement , options .QueryOptions )
13721523}
13731524
1374- func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
1525+ func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , cancel context.CancelFunc , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
1526+ // The cancel() function is called by the returned Rows object when it is closed.
1527+ // However, if an error is returned instead of a Rows instance, we need to cancel
1528+ // the context when we return from this function.
1529+ defer func () {
1530+ if returnedErr != nil {
1531+ cancel ()
1532+ }
1533+ }()
13751534 tx , err := c .BeginTx (ctx , driver.TxOptions {ReadOnly : true , Isolation : withBatchReadOnly (driver .IsolationLevel (sql .LevelDefault ))})
13761535 if err != nil {
13771536 return nil , err
@@ -1383,6 +1542,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831542 }
13841543 if rows , ok := r .(* rows ); ok {
13851544 rows .close = func () error {
1545+ defer cancel ()
13861546 return tx .Commit ()
13871547 }
13881548 }
0 commit comments