@@ -831,6 +831,51 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831831 return & stmt {conn : c , query : parsedSQL , numArgs : len (args ), execOptions : execOptions }, nil
832832}
833833
834+ func (c * conn ) addStatementAndTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc , error ) {
835+ var statementDeadline time.Time
836+ var transactionDeadline time.Time
837+ var deadline time.Time
838+ var hasStatementDeadline bool
839+ var hasTransactionDeadline bool
840+
841+ statementTimeout := propertyStatementTimeout .GetValueOrDefault (c .state )
842+ if statementTimeout != time .Duration (0 ) {
843+ hasStatementDeadline = true
844+ statementDeadline = time .Now ().Add (statementTimeout )
845+ }
846+ transactionDeadline , hasTransactionDeadline , err := c .transactionDeadline ()
847+ if err != nil {
848+ return nil , nil , err
849+ }
850+ if ! hasStatementDeadline && ! hasTransactionDeadline {
851+ return ctx , func () {}, nil
852+ }
853+ if hasTransactionDeadline && hasStatementDeadline {
854+ if statementDeadline .Before (transactionDeadline ) {
855+ deadline = statementDeadline
856+ } else {
857+ deadline = transactionDeadline
858+ }
859+ } else if hasStatementDeadline {
860+ deadline = statementDeadline
861+ } else {
862+ deadline = transactionDeadline
863+ }
864+ newCtx , cancel := context .WithDeadline (ctx , deadline )
865+ return newCtx , cancel , nil
866+ }
867+
868+ func (c * conn ) transactionDeadline () (time.Time , bool , error ) {
869+ if c .tx == nil {
870+ return time.Time {}, false , nil
871+ }
872+ if err := c .tx .ensureActivated (); err != nil {
873+ return time.Time {}, false , err
874+ }
875+ deadline , hasDeadline := c .tx .deadline ()
876+ return deadline , hasDeadline , nil
877+ }
878+
834879func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
835880 // Execute client side statement if it is one.
836881 clientStmt , err := c .parser .ParseClientSideStatement (query )
@@ -849,13 +894,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849894 return c .queryContext (ctx , query , execOptions , args )
850895}
851896
852- func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
897+ func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
898+ ctx , cancel , err := c .addStatementAndTransactionTimeout (ctx )
899+ if err != nil {
900+ return nil , err
901+ }
902+ defer func () {
903+ if returnedErr != nil {
904+ cancel ()
905+ }
906+ }()
853907 // Clear the commit timestamp of this connection before we execute the query.
854908 c .clearCommitResponse ()
855909 // Check if the execution options contains an instruction to execute
856910 // a specific partition of a PartitionedQuery.
857911 if pq := execOptions .PartitionedQueryOptions .ExecutePartition .PartitionedQuery ; pq != nil {
858- return pq .execute (ctx , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
912+ return pq .execute (ctx , cancel , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
859913 }
860914
861915 stmt , err := prepareSpannerStmt (c .parser , query , args )
@@ -869,7 +923,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869923 if err != nil {
870924 return nil , err
871925 }
872- return createDriverResultRows (res , execOptions ), nil
926+ return createDriverResultRows (res , cancel , execOptions ), nil
873927 }
874928 var iter rowIterator
875929 if c .tx == nil {
@@ -884,7 +938,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884938 } else if execOptions .PartitionedQueryOptions .PartitionQuery {
885939 return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "PartitionQuery is only supported in batch read-only transactions" ))
886940 } else if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
887- return c .executeAutoPartitionedQuery (ctx , query , execOptions , args )
941+ return c .executeAutoPartitionedQuery (ctx , cancel , query , execOptions , args )
888942 } else {
889943 // The statement was either detected as being a query, or potentially not recognized at all.
890944 // In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,14 +947,15 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893947 }
894948 } else {
895949 if execOptions .PartitionedQueryOptions .PartitionQuery {
950+ defer cancel ()
896951 return c .tx .partitionQuery (ctx , stmt , execOptions )
897952 }
898953 iter , err = c .tx .Query (ctx , stmt , statementInfo .StatementType , execOptions )
899954 if err != nil {
900955 return nil , err
901956 }
902957 }
903- res := createRows (iter , execOptions )
958+ res := createRows (iter , cancel , execOptions )
904959 if execOptions .DirectExecuteQuery {
905960 // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906961 res .getColumns ()
@@ -929,7 +984,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
929984 return c .execContext (ctx , query , execOptions , args )
930985}
931986
932- func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Result , error ) {
987+ func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedResult driver.Result , returnedErr error ) {
988+ ctx , cancel , err := c .addStatementAndTransactionTimeout (ctx )
989+ if err != nil {
990+ return nil , err
991+ }
992+ defer func () {
993+ if returnedErr != nil {
994+ cancel ()
995+ }
996+ }()
933997 // Clear the commit timestamp of this connection before we execute the statement.
934998 c .clearCommitResponse ()
935999
@@ -1041,6 +1105,17 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411105 return noTransaction ()
10421106 }
10431107 c .tx = c .prevTx
1108+ // If the aborted error happened during the Commit, then the transaction context has been cancelled,
1109+ // and we need to create a new one.
1110+ if rwTx , ok := c .tx .contextTransaction .(* readWriteTransaction ); ok {
1111+ newCtx , cancel := c .addTransactionTimeout (c .tx .ctx )
1112+ rwTx .ctx = newCtx
1113+ origClose := rwTx .close
1114+ rwTx .close = func (result txResult , commitResponse * spanner.CommitResponse , commitErr error ) {
1115+ origClose (result , commitResponse , commitErr )
1116+ cancel ()
1117+ }
1118+ }
10441119 c .resetForRetry = true
10451120 } else if c .tx == nil {
10461121 return noTransaction ()
@@ -1248,6 +1323,14 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481323 return c .tx , nil
12491324}
12501325
1326+ func (c * conn ) addTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc ) {
1327+ timeout := propertyTransactionTimeout .GetValueOrDefault (c .state )
1328+ if timeout == time .Duration (0 ) {
1329+ return ctx , func () {}
1330+ }
1331+ return context .WithTimeout (ctx , timeout )
1332+ }
1333+
12511334func (c * conn ) activateTransaction () (contextTransaction , error ) {
12521335 closeFunc := c .tx .close
12531336 if propertyTransactionReadOnly .GetValueOrDefault (c .state ) {
@@ -1283,19 +1366,21 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831366 opts := spanner.TransactionOptions {}
12841367 opts .BeginTransactionOption = c .convertDefaultBeginTransactionOption (propertyBeginTransactionOption .GetValueOrDefault (c .state ))
12851368
1286- tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (c .tx .ctx , c .client , opts , func () spanner.TransactionOptions {
1369+ ctx , cancel := c .addTransactionTimeout (c .tx .ctx )
1370+ tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (ctx , c .client , opts , func () spanner.TransactionOptions {
12871371 defer func () {
12881372 // Reset the transaction_tag after starting the transaction.
12891373 _ = propertyTransactionTag .ResetValue (c .state , connectionstate .ContextUser )
12901374 }()
12911375 return c .effectiveTransactionOptions (spannerpb .TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED , c .options ( /*reset=*/ true ))
12921376 })
12931377 if err != nil {
1378+ cancel ()
12941379 return nil , err
12951380 }
12961381 logger := c .logger .With ("tx" , "rw" )
12971382 return & readWriteTransaction {
1298- ctx : c . tx . ctx ,
1383+ ctx : ctx ,
12991384 conn : c ,
13001385 logger : logger ,
13011386 rwTx : tx ,
@@ -1307,6 +1392,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071392 } else {
13081393 closeFunc (txResultRollback )
13091394 }
1395+ cancel ()
13101396 },
13111397 retryAborts : sync .OnceValue (func () bool {
13121398 return c .RetryAbortsInternally ()
@@ -1371,18 +1457,20 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711457 return c .Single ().WithTimestampBound (tb ).QueryWithOptions (ctx , statement , options .QueryOptions )
13721458}
13731459
1374- func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
1460+ func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , cancel context. CancelFunc , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
13751461 tx , err := c .BeginTx (ctx , driver.TxOptions {ReadOnly : true , Isolation : withBatchReadOnly (driver .IsolationLevel (sql .LevelDefault ))})
13761462 if err != nil {
13771463 return nil , err
13781464 }
13791465 r , err := c .queryContext (ctx , query , execOptions , args )
13801466 if err != nil {
13811467 _ = tx .Rollback ()
1468+ cancel ()
13821469 return nil , err
13831470 }
13841471 if rows , ok := r .(* rows ); ok {
13851472 rows .close = func () error {
1473+ defer cancel ()
13861474 return tx .Commit ()
13871475 }
13881476 }
0 commit comments