@@ -831,6 +831,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831831 return & stmt {conn : c , query : parsedSQL , numArgs : len (args ), execOptions : execOptions }, nil
832832}
833833
834+ // Adds any statement or transaction timeout to the given context. The deadline of the returned
835+ // context will be the earliest of:
836+ // 1. Any existing deadline on the input context.
837+ // 2. Any existing transaction deadline.
838+ // 3. A deadline calculated from the current time + the value of statement_timeout.
839+ func (c * conn ) addStatementAndTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc , error ) {
840+ var statementDeadline time.Time
841+ var transactionDeadline time.Time
842+ var deadline time.Time
843+ var hasStatementDeadline bool
844+ var hasTransactionDeadline bool
845+
846+ // Check if the connection has a value for statement_timeout.
847+ statementTimeout := propertyStatementTimeout .GetValueOrDefault (c .state )
848+ if statementTimeout != time .Duration (0 ) {
849+ hasStatementDeadline = true
850+ statementDeadline = time .Now ().Add (statementTimeout )
851+ }
852+ // Check if the current transaction has a deadline.
853+ transactionDeadline , hasTransactionDeadline , err := c .transactionDeadline ()
854+ if err != nil {
855+ return nil , nil , err
856+ }
857+
858+ // If there is no statement_timeout and no current transaction deadline,
859+ // then can just use the input context as-is.
860+ if ! hasStatementDeadline && ! hasTransactionDeadline {
861+ return ctx , func () {}, nil
862+ }
863+
864+ // If there is both a transaction and a statement deadline, then we use the earliest
865+ // of those two.
866+ if hasTransactionDeadline && hasStatementDeadline {
867+ if statementDeadline .Before (transactionDeadline ) {
868+ deadline = statementDeadline
869+ } else {
870+ deadline = transactionDeadline
871+ }
872+ } else if hasStatementDeadline {
873+ deadline = statementDeadline
874+ } else {
875+ deadline = transactionDeadline
876+ }
877+ // context.WithDeadline automatically selects the earliest deadline of
878+ // the existing deadline on the context and the given deadline.
879+ newCtx , cancel := context .WithDeadline (ctx , deadline )
880+ return newCtx , cancel , nil
881+ }
882+
883+ // transactionDeadline returns the deadline of the current transaction
884+ // on the connection. This also activates the transaction if it is not
885+ // yet activated.
886+ func (c * conn ) transactionDeadline () (time.Time , bool , error ) {
887+ if c .tx == nil {
888+ return time.Time {}, false , nil
889+ }
890+ if err := c .tx .ensureActivated (); err != nil {
891+ return time.Time {}, false , err
892+ }
893+ deadline , hasDeadline := c .tx .deadline ()
894+ return deadline , hasDeadline , nil
895+ }
896+
834897func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
835898 // Execute client side statement if it is one.
836899 clientStmt , err := c .parser .ParseClientSideStatement (query )
@@ -849,13 +912,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849912 return c .queryContext (ctx , query , execOptions , args )
850913}
851914
852- func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
915+ func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
916+ ctx , cancelCause := context .WithCancelCause (ctx )
917+ cancel := func () {
918+ cancelCause (nil )
919+ }
920+ defer func () {
921+ if returnedErr != nil {
922+ cancel ()
923+ }
924+ }()
853925 // Clear the commit timestamp of this connection before we execute the query.
854926 c .clearCommitResponse ()
855927 // Check if the execution options contains an instruction to execute
856928 // a specific partition of a PartitionedQuery.
857929 if pq := execOptions .PartitionedQueryOptions .ExecutePartition .PartitionedQuery ; pq != nil {
858- return pq .execute (ctx , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
930+ return pq .execute (ctx , cancel , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
859931 }
860932
861933 stmt , err := prepareSpannerStmt (c .parser , query , args )
@@ -869,7 +941,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869941 if err != nil {
870942 return nil , err
871943 }
872- return createDriverResultRows (res , execOptions ), nil
944+ return createDriverResultRows (res , cancel , execOptions ), nil
873945 }
874946 var iter rowIterator
875947 if c .tx == nil {
@@ -884,7 +956,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884956 } else if execOptions .PartitionedQueryOptions .PartitionQuery {
885957 return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "PartitionQuery is only supported in batch read-only transactions" ))
886958 } else if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
887- return c .executeAutoPartitionedQuery (ctx , query , execOptions , args )
959+ return c .executeAutoPartitionedQuery (ctx , cancel , query , execOptions , args )
888960 } else {
889961 // The statement was either detected as being a query, or potentially not recognized at all.
890962 // In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +965,75 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893965 }
894966 } else {
895967 if execOptions .PartitionedQueryOptions .PartitionQuery {
968+ // The driver.Rows instance that is returned for partitionQuery does not
969+ // contain a context, and therefore also does not cancel the context when it is closed.
970+ defer cancel ()
896971 return c .tx .partitionQuery (ctx , stmt , execOptions )
897972 }
898973 iter , err = c .tx .Query (ctx , stmt , statementInfo .StatementType , execOptions )
899974 if err != nil {
900975 return nil , err
901976 }
902977 }
903- res := createRows (iter , execOptions )
978+ res := createRows (iter , cancel , execOptions )
904979 if execOptions .DirectExecuteQuery {
905- // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906- res .getColumns ()
907- if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
908- _ = res .Close ()
909- return nil , res .dirtyErr
980+ if err := c .directExecuteQuery (ctx , cancelCause , res , execOptions ); err != nil {
981+ return nil , err
910982 }
911983 }
912984 return res , nil
913985}
914986
987+ // directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or
988+ // transaction_timeout is used while waiting for the first result to be returned.
989+ func (c * conn ) directExecuteQuery (ctx context.Context , cancelQuery context.CancelCauseFunc , res * rows , execOptions * ExecOptions ) error {
990+ statementCtx := ctx
991+ if execOptions .DirectExecuteContext != nil {
992+ statementCtx = execOptions .DirectExecuteContext
993+ }
994+ // Add the statement or transaction deadline to the context.
995+ statementCtx , cancelStatement , err := c .addStatementAndTransactionTimeout (statementCtx )
996+ if err != nil {
997+ return err
998+ }
999+ defer cancelStatement ()
1000+
1001+ // Asynchronously fetch the first partial result set from Spanner.
1002+ done := make (chan struct {})
1003+ go func () {
1004+ // Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the
1005+ // metadata of the query.
1006+ defer close (done )
1007+ res .getColumns ()
1008+ }()
1009+ // Wait until either the done channel is closed or the context is done.
1010+ var statementErr error
1011+ select {
1012+ case <- statementCtx .Done ():
1013+ statementErr = statementCtx .Err ()
1014+ // Cancel the query execution.
1015+ cancelQuery (statementCtx .Err ())
1016+ case <- done :
1017+ }
1018+
1019+ // Now wait until done channel is closed. This could be because the execution finished
1020+ // successfully, or because the context was cancelled, which again causes the execution
1021+ // to (eventually) fail.
1022+ <- done
1023+ if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
1024+ _ = res .Close ()
1025+ if statementErr != nil {
1026+ // Create a status error from the statement error and wrap both the Spanner error and the status error into
1027+ // one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1028+ // ID from the Spanner error.
1029+ s := status .FromContextError (statementErr )
1030+ return errors .Join (s .Err (), res .dirtyErr )
1031+ }
1032+ return res .dirtyErr
1033+ }
1034+ return nil
1035+ }
1036+
9151037func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
9161038 // Execute client side statement if it is one.
9171039 stmt , err := c .parser .ParseClientSideStatement (query )
@@ -929,7 +1051,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291051 return c .execContext (ctx , query , execOptions , args )
9301052}
9311053
932- func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Result , error ) {
1054+ func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedResult driver.Result , returnedErr error ) {
1055+ // Add the statement/transaction deadline to the context.
1056+ ctx , cancel , err := c .addStatementAndTransactionTimeout (ctx )
1057+ if err != nil {
1058+ return nil , err
1059+ }
1060+ defer cancel ()
9331061 // Clear the commit timestamp of this connection before we execute the statement.
9341062 c .clearCommitResponse ()
9351063
@@ -1041,6 +1169,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411169 return noTransaction ()
10421170 }
10431171 c .tx = c .prevTx
1172+ // If the aborted error happened during the Commit, then the transaction
1173+ // context has been cancelled, and we need to create a new one.
1174+ if rwTx , ok := c .tx .contextTransaction .(* readWriteTransaction ); ok {
1175+ newCtx , cancel := c .addTransactionTimeout (c .tx .ctx )
1176+ rwTx .ctx = newCtx
1177+ // Make sure that we cancel the new context when the transaction is closed.
1178+ origClose := rwTx .close
1179+ rwTx .close = func (result txResult , commitResponse * spanner.CommitResponse , commitErr error ) {
1180+ origClose (result , commitResponse , commitErr )
1181+ cancel ()
1182+ }
1183+ }
10441184 c .resetForRetry = true
10451185 } else if c .tx == nil {
10461186 return noTransaction ()
@@ -1248,6 +1388,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481388 return c .tx , nil
12491389}
12501390
1391+ // addTransactionTimeout creates a new derived context with the current transaction_timeout.
1392+ func (c * conn ) addTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc ) {
1393+ timeout := propertyTransactionTimeout .GetValueOrDefault (c .state )
1394+ if timeout == time .Duration (0 ) {
1395+ return ctx , func () {}
1396+ }
1397+ // Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated
1398+ // deadline based on the timeout.
1399+ return context .WithTimeout (ctx , timeout )
1400+ }
1401+
12511402func (c * conn ) activateTransaction () (contextTransaction , error ) {
12521403 closeFunc := c .tx .close
12531404 if propertyTransactionReadOnly .GetValueOrDefault (c .state ) {
@@ -1283,19 +1434,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831434 opts := spanner.TransactionOptions {}
12841435 opts .BeginTransactionOption = c .convertDefaultBeginTransactionOption (propertyBeginTransactionOption .GetValueOrDefault (c .state ))
12851436
1286- tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (c .tx .ctx , c .client , opts , func () spanner.TransactionOptions {
1437+ // Add the current value of transaction_timeout to the context that is registered
1438+ // on the transaction.
1439+ ctx , cancel := c .addTransactionTimeout (c .tx .ctx )
1440+ tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (ctx , c .client , opts , func () spanner.TransactionOptions {
12871441 defer func () {
12881442 // Reset the transaction_tag after starting the transaction.
12891443 _ = propertyTransactionTag .ResetValue (c .state , connectionstate .ContextUser )
12901444 }()
12911445 return c .effectiveTransactionOptions (spannerpb .TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED , c .options ( /*reset=*/ true ))
12921446 })
12931447 if err != nil {
1448+ cancel ()
12941449 return nil , err
12951450 }
12961451 logger := c .logger .With ("tx" , "rw" )
12971452 return & readWriteTransaction {
1298- ctx : c . tx . ctx ,
1453+ ctx : ctx ,
12991454 conn : c ,
13001455 logger : logger ,
13011456 rwTx : tx ,
@@ -1307,6 +1462,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071462 } else {
13081463 closeFunc (txResultRollback )
13091464 }
1465+ cancel ()
13101466 },
13111467 retryAborts : sync .OnceValue (func () bool {
13121468 return c .RetryAbortsInternally ()
@@ -1371,7 +1527,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711527 return c .Single ().WithTimestampBound (tb ).QueryWithOptions (ctx , statement , options .QueryOptions )
13721528}
13731529
1374- func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
1530+ func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , cancel context.CancelFunc , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
1531+ // The cancel() function is called by the returned Rows object when it is closed.
1532+ // However, if an error is returned instead of a Rows instance, we need to cancel
1533+ // the context when we return from this function.
1534+ defer func () {
1535+ if returnedErr != nil {
1536+ cancel ()
1537+ }
1538+ }()
13751539 tx , err := c .BeginTx (ctx , driver.TxOptions {ReadOnly : true , Isolation : withBatchReadOnly (driver .IsolationLevel (sql .LevelDefault ))})
13761540 if err != nil {
13771541 return nil , err
@@ -1383,6 +1547,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831547 }
13841548 if rows , ok := r .(* rows ); ok {
13851549 rows .close = func () error {
1550+ defer cancel ()
13861551 return tx .Commit ()
13871552 }
13881553 }
0 commit comments