diff --git a/spanner/transaction.go b/spanner/transaction.go index f9bce761fa54..0666388e2bf1 100644 --- a/spanner/transaction.go +++ b/spanner/transaction.go @@ -432,6 +432,10 @@ func errMultipleRowsFound(table string, key Key, index string) error { return spannerErrorf(codes.FailedPrecondition, "more than one row found by index(Table: %v, IndexKey: %v, Index: %v)", table, key, index) } +func errTransactionNoLongerActive() error { + return spannerError(codes.FailedPrecondition, "the transaction that was used to execute this statement is no longer active") +} + const errInlineBeginTransactionFailedMsg = "failed inline begin transaction" // errInlineBeginTransactionFailed creates an error that indicates that the first statement in the @@ -701,6 +705,12 @@ func (t *txReadOnly) query(ctx context.Context, statement Statement, options Que sh.session.logger, t.sp.sc.metricsTracerFactory, func(ctx context.Context, resumeToken []byte, opts ...gax.CallOption) (streamingReceiver, error) { + // The session handle is removed from the transaction when the transaction is committed or rolled back. + // This ensures that we return a reasonable error instead of panic if the application tries to use the + // stream after the transaction has finished. + if t.sh == nil { + return nil, errTransactionNoLongerActive() + } req.ResumeToken = resumeToken req.Session = t.sh.getID() req.Transaction = t.getTransactionSelector() diff --git a/spanner/transaction_test.go b/spanner/transaction_test.go index e496a557f258..445ef7226834 100644 --- a/spanner/transaction_test.go +++ b/spanner/transaction_test.go @@ -2222,6 +2222,53 @@ func TestReadWriteTransactionUsesNewContextForRollback(t *testing.T) { } } +func TestReadFromQueryAfterCommitOrRollback(t *testing.T) { + t.Parallel() + + ctx := context.Background() + _, client, teardown := setupMockedTestServer(t) + defer teardown() + + testcases := []struct { + name string + commit bool + }{ + {name: "AfterCommit", commit: true}, + {name: "AfterRollback", commit: false}, + } + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + // Create a new transaction and execute a query using that transaction. + // Then try to read data from the row iterator after the transaction has finished. + tx, err := NewReadWriteStmtBasedTransaction(ctx, client) + if err != nil { + t.Fatalf("failed to create transaction: %v", err) + } + // 'Execute' the query using the transaction. Note that the query is only actually executed the first time + // that RowIterator.Next() is called. + it := tx.Query(ctx, NewStatement(SelectFooFromBar)) + // Commit or rollback the transaction before reading any data. + if tc.commit { + if _, err := tx.Commit(ctx); err != nil { + t.Fatalf("failed to commit: %v", err) + } + } else { + tx.Rollback(ctx) + } + + // Now try to read the data from the RowIterator that was returned for the query. + _, err = it.Next() + if g, w := ErrCode(err), codes.FailedPrecondition; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := err.Error(), `spanner: code = "FailedPrecondition", desc = "the transaction that was used to execute this statement is no longer active"`; g != w { + t.Fatalf("error message mismatch\n Got: %v\nWant: %v", g, w) + } + it.Stop() + }) + } +} + // shouldHaveReceived asserts that exactly expectedRequests were present in // the server's ReceivedRequests channel. It only looks at type, not contents. //