Skip to content

Commit a2fe03d

Browse files
committed
chore: store temp TransactionOptions in connection state
Store temporary TransactionOptions in the connection state as local options. Local options only apply to the current transaction. This simplifies the internal state handling of the driver, as all transaction state should only be read from the connection state, and not also from a temporary variable. This also enables the use of a combination of temporary transaction options and using SQL statements to set further options. The shared library always includes temporary transaction options, as the BeginTransaction function accepts TransactionOptions as an input argument. This meant that using SQL statements to set further transaction options was not supported through the shared library.
1 parent 13bda8d commit a2fe03d

File tree

5 files changed

+190
-29
lines changed

5 files changed

+190
-29
lines changed

conn.go

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,9 @@ type conn struct {
275275
// tempExecOptions can be set by passing it in as an argument to ExecContext or QueryContext
276276
// and are applied only to that statement.
277277
tempExecOptions *ExecOptions
278-
// tempTransactionOptions are temporarily set right before a read/write transaction is started.
279-
tempTransactionOptions *ReadWriteTransactionOptions
278+
// tempTransactionCloseFunc is set right before a transaction is started, and is set as the
279+
// close function for that transaction.
280+
tempTransactionCloseFunc func()
280281
// tempReadOnlyTransactionOptions are temporarily set right before a read-only
281282
// transaction is started on a Spanner connection.
282283
tempReadOnlyTransactionOptions *ReadOnlyTransactionOptions
@@ -1011,8 +1012,10 @@ func (c *conn) options(reset bool) *ExecOptions {
10111012
TransactionTag: c.TransactionTag(),
10121013
IsolationLevel: toProtoIsolationLevelOrDefault(c.IsolationLevel()),
10131014
ReadLockMode: c.ReadLockMode(),
1015+
CommitPriority: propertyCommitPriority.GetValueOrDefault(c.state),
10141016
CommitOptions: spanner.CommitOptions{
1015-
MaxCommitDelay: c.maxCommitDelayPointer(),
1017+
MaxCommitDelay: c.maxCommitDelayPointer(),
1018+
ReturnCommitStats: propertyReturnCommitStats.GetValueOrDefault(c.state),
10161019
},
10171020
},
10181021
PartitionedQueryOptions: PartitionedQueryOptions{},
@@ -1045,16 +1048,43 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10451048
}
10461049

10471050
func (c *conn) withTempTransactionOptions(options *ReadWriteTransactionOptions) {
1048-
c.tempTransactionOptions = options
1051+
if options == nil {
1052+
return
1053+
}
1054+
c.tempTransactionCloseFunc = options.close
1055+
// Start a transaction for the connection state, so we can set the transaction options
1056+
// as local options in the current transaction.
1057+
_ = c.state.Begin()
1058+
if options.DisableInternalRetries {
1059+
_ = propertyRetryAbortsInternally.SetLocalValue(c.state, !options.DisableInternalRetries)
1060+
}
1061+
if options.TransactionOptions.BeginTransactionOption != spanner.DefaultBeginTransaction {
1062+
_ = propertyBeginTransactionOption.SetLocalValue(c.state, options.TransactionOptions.BeginTransactionOption)
1063+
}
1064+
if options.TransactionOptions.CommitOptions.MaxCommitDelay != nil {
1065+
_ = propertyMaxCommitDelay.SetLocalValue(c.state, *options.TransactionOptions.CommitOptions.MaxCommitDelay)
1066+
}
1067+
if options.TransactionOptions.CommitOptions.ReturnCommitStats {
1068+
_ = propertyReturnCommitStats.SetLocalValue(c.state, options.TransactionOptions.CommitOptions.ReturnCommitStats)
1069+
}
1070+
if options.TransactionOptions.TransactionTag != "" {
1071+
_ = propertyTransactionTag.SetLocalValue(c.state, options.TransactionOptions.TransactionTag)
1072+
}
1073+
if options.TransactionOptions.ReadLockMode != spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED {
1074+
_ = propertyReadLockMode.SetLocalValue(c.state, options.TransactionOptions.ReadLockMode)
1075+
}
1076+
if options.TransactionOptions.IsolationLevel != spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED {
1077+
_ = propertyIsolationLevel.SetLocalValue(c.state, toSqlIsolationLevelOrDefault(options.TransactionOptions.IsolationLevel))
1078+
}
1079+
if options.TransactionOptions.ExcludeTxnFromChangeStreams {
1080+
_ = propertyExcludeTxnFromChangeStreams.SetLocalValue(c.state, options.TransactionOptions.ExcludeTxnFromChangeStreams)
1081+
}
1082+
if options.TransactionOptions.CommitPriority != spannerpb.RequestOptions_PRIORITY_UNSPECIFIED {
1083+
_ = propertyCommitPriority.SetLocalValue(c.state, options.TransactionOptions.CommitPriority)
1084+
}
10491085
}
10501086

10511087
func (c *conn) getTransactionOptions(execOptions *ExecOptions) ReadWriteTransactionOptions {
1052-
if c.tempTransactionOptions != nil {
1053-
defer func() { c.tempTransactionOptions = nil }()
1054-
opts := *c.tempTransactionOptions
1055-
opts.TransactionOptions.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.TransactionOptions.BeginTransactionOption)
1056-
return opts
1057-
}
10581088
txOpts := ReadWriteTransactionOptions{
10591089
TransactionOptions: execOptions.TransactionOptions,
10601090
DisableInternalRetries: !c.RetryAbortsInternally(),
@@ -1122,7 +1152,6 @@ func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWrite
11221152
c.withTempTransactionOptions(options)
11231153
tx, err := c.BeginTx(ctx, driver.TxOptions{})
11241154
if err != nil {
1125-
c.withTempTransactionOptions(nil)
11261155
return nil, err
11271156
}
11281157
return tx, nil
@@ -1133,6 +1162,13 @@ func (c *conn) Begin() (driver.Tx, error) {
11331162
}
11341163

11351164
func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver.Tx, error) {
1165+
defer func() {
1166+
c.tempTransactionCloseFunc = nil
1167+
}()
1168+
return c.beginTx(ctx, driverOpts, c.tempTransactionCloseFunc)
1169+
}
1170+
1171+
func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFunc func()) (driver.Tx, error) {
11361172
if c.resetForRetry {
11371173
c.resetForRetry = false
11381174
return c.tx, nil
@@ -1141,6 +1177,10 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
11411177
defer func() {
11421178
if c.tx != nil {
11431179
_ = c.state.Begin()
1180+
} else {
1181+
// Rollback in case the connection state transaction was started before this function
1182+
// was called, for example if the caller set temporary transaction options.
1183+
_ = c.state.Rollback()
11441184
}
11451185
}()
11461186

@@ -1219,17 +1259,18 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
12191259
return c.tx, nil
12201260
}
12211261

1262+
// These options are only used to determine how to start the transaction.
1263+
// All other options are fetched in a callback that is called when the transaction is actually started.
1264+
// That callback reads all transaction options from the connection state at that moment. This allows
1265+
// applications to execute a series of statement like this:
1266+
// BEGIN TRANSACTION;
1267+
// SET LOCAL transaction_tag='my_tag';
1268+
// SET LOCAL commit_priority=LOW;
1269+
// INSERT INTO my_table ... -- This starts the transaction with the options above included.
12221270
opts := spanner.TransactionOptions{}
1223-
if c.tempTransactionOptions != nil {
1224-
opts = c.tempTransactionOptions.TransactionOptions
1225-
}
1226-
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(opts.BeginTransactionOption)
1227-
tempCloseFunc := func() {}
1228-
if c.tempTransactionOptions != nil && c.tempTransactionOptions.close != nil {
1229-
tempCloseFunc = c.tempTransactionOptions.close
1230-
}
1231-
if !disableRetryAborts && c.tempTransactionOptions != nil {
1232-
disableRetryAborts = c.tempTransactionOptions.DisableInternalRetries
1271+
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))
1272+
if closeFunc == nil {
1273+
closeFunc = func() {}
12331274
}
12341275

12351276
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
@@ -1249,7 +1290,7 @@ func (c *conn) BeginTx(ctx context.Context, driverOpts driver.TxOptions) (driver
12491290
logger: logger,
12501291
rwTx: tx,
12511292
close: func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1252-
tempCloseFunc()
1293+
closeFunc()
12531294
c.prevTx = c.tx
12541295
c.tx = nil
12551296
if commitErr == nil {

connection_properties.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,27 @@ var propertyMaxCommitDelay = createConnectionProperty(
257257
connectionstate.ContextUser,
258258
connectionstate.ConvertDuration,
259259
)
260+
var propertyCommitPriority = createConnectionProperty(
261+
"commit_priority",
262+
"Sets the priority for commit RPC invocations from this connection (HIGH/MEDIUM/LOW/UNSPECIFIED). "+
263+
"The default is UNSPECIFIED.",
264+
spannerpb.RequestOptions_PRIORITY_UNSPECIFIED,
265+
false,
266+
nil,
267+
connectionstate.ContextUser,
268+
func(value string) (spannerpb.RequestOptions_Priority, error) {
269+
return parseRpcPriority(value)
270+
},
271+
)
272+
var propertyReturnCommitStats = createConnectionProperty(
273+
"return_commit_stats",
274+
"return_commit_stats determines whether transactions should request Spanner to return commit statistics.",
275+
false,
276+
false,
277+
nil,
278+
connectionstate.ContextUser,
279+
connectionstate.ConvertBool,
280+
)
260281

261282
// ------------------------------------------------------------------------------------------------
262283
// Statement connection properties.

driver.go

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,7 +1148,6 @@ func BeginReadWriteTransaction(ctx context.Context, db *sql.DB, options ReadWrit
11481148
}
11491149
tx, err := conn.BeginTx(ctx, &sql.TxOptions{})
11501150
if err != nil {
1151-
clearTempReadWriteTransactionOptions(conn)
11521151
return nil, err
11531152
}
11541153
return tx, nil
@@ -1166,11 +1165,6 @@ func withTempReadWriteTransactionOptions(conn *sql.Conn, options *ReadWriteTrans
11661165
})
11671166
}
11681167

1169-
func clearTempReadWriteTransactionOptions(conn *sql.Conn) {
1170-
_ = withTempReadWriteTransactionOptions(conn, nil)
1171-
_ = conn.Close()
1172-
}
1173-
11741168
// ReadOnlyTransactionOptions can be used to create a read-only transaction
11751169
// on a Spanner connection.
11761170
type ReadOnlyTransactionOptions struct {
@@ -1529,6 +1523,24 @@ func toProtoIsolationLevelOrDefault(level sql.IsolationLevel) spannerpb.Transact
15291523
return res
15301524
}
15311525

1526+
func toSqlIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) (sql.IsolationLevel, error) {
1527+
switch level {
1528+
case spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED:
1529+
return sql.LevelDefault, nil
1530+
case spannerpb.TransactionOptions_SERIALIZABLE:
1531+
return sql.LevelSerializable, nil
1532+
case spannerpb.TransactionOptions_REPEATABLE_READ:
1533+
return sql.LevelRepeatableRead, nil
1534+
default:
1535+
}
1536+
return sql.LevelDefault, spanner.ToSpannerError(status.Errorf(codes.InvalidArgument, "invalid or unsupported isolation level: %v", level))
1537+
}
1538+
1539+
func toSqlIsolationLevelOrDefault(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel {
1540+
res, _ := toSqlIsolationLevel(level)
1541+
return res
1542+
}
1543+
15321544
type spannerIsolationLevel sql.IsolationLevel
15331545

15341546
const (

driver_with_mockserver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5076,7 +5076,7 @@ func TestBeginReadWriteTransaction(t *testing.T) {
50765076
t.Fatalf("missing transaction for ExecuteSqlRequest")
50775077
}
50785078
if req.Transaction.GetId() == nil {
5079-
t.Fatalf("missing begin selector for ExecuteSqlRequest")
5079+
t.Fatalf("missing ID selector for ExecuteSqlRequest")
50805080
}
50815081
if g, w := req.RequestOptions.TransactionTag, tag; g != w {
50825082
t.Fatalf("transaction tag mismatch\n Got: %v\nWant: %v", g, w)

spannerlib/api/transaction_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"cloud.google.com/go/spanner/apiv1/spannerpb"
2525
"github.com/googleapis/go-sql-spanner/testutil"
2626
"google.golang.org/grpc/codes"
27+
"google.golang.org/grpc/status"
2728
)
2829

2930
func TestBeginAndCommit(t *testing.T) {
@@ -409,3 +410,89 @@ func TestDdlInTransaction(t *testing.T) {
409410
t.Fatalf("ClosePool returned unexpected error: %v", err)
410411
}
411412
}
413+
414+
func TestTransactionOptionsAsSqlStatements(t *testing.T) {
415+
t.Parallel()
416+
417+
ctx := context.Background()
418+
server, teardown := setupMockServer(t)
419+
defer teardown()
420+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address)
421+
422+
poolId, err := CreatePool(ctx, dsn)
423+
if err != nil {
424+
t.Fatalf("CreatePool returned unexpected error: %v", err)
425+
}
426+
connId, err := CreateConnection(ctx, poolId)
427+
if err != nil {
428+
t.Fatalf("CreateConnection returned unexpected error: %v", err)
429+
}
430+
if err := BeginTransaction(ctx, poolId, connId, &spannerpb.TransactionOptions{}); err != nil {
431+
t.Fatalf("BeginTransaction returned unexpected error: %v", err)
432+
}
433+
434+
// Set some local transaction options.
435+
if rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: "set local transaction_tag = 'my_tag'"}); err != nil {
436+
t.Fatalf("setting transaction_tag returned unexpected error: %v", err)
437+
} else {
438+
_ = CloseRows(ctx, poolId, connId, rowsId)
439+
}
440+
if rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: "set local retry_aborts_internally = false"}); err != nil {
441+
t.Fatalf("setting retry_aborts_internally returned unexpected error: %v", err)
442+
} else {
443+
_ = CloseRows(ctx, poolId, connId, rowsId)
444+
}
445+
446+
// Execute a statement in the transaction.
447+
if rowsId, err := Execute(ctx, poolId, connId, &spannerpb.ExecuteSqlRequest{Sql: testutil.UpdateBarSetFoo}); err != nil {
448+
t.Fatalf("Execute returned unexpected error: %v", err)
449+
} else {
450+
_ = CloseRows(ctx, poolId, connId, rowsId)
451+
}
452+
453+
// Abort the transaction to verify that the retry_aborts_internally setting was respected.
454+
server.TestSpanner.PutExecutionTime(testutil.MethodCommitTransaction, testutil.SimulatedExecutionTime{
455+
Errors: []error{status.Error(codes.Aborted, "Aborted")},
456+
})
457+
458+
// Commit the transaction. This should fail with an Aborted error.
459+
if _, err := Commit(ctx, poolId, connId); err == nil {
460+
t.Fatal("missing expected error")
461+
} else {
462+
if g, w := spanner.ErrCode(err), codes.Aborted; g != w {
463+
t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w)
464+
}
465+
}
466+
467+
// Verify that the transaction_tag setting was respected.
468+
requests := server.TestSpanner.DrainRequestsFromServer()
469+
executeRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.ExecuteSqlRequest{}))
470+
if g, w := len(executeRequests), 1; g != w {
471+
t.Fatalf("Execute request count mismatch\n Got: %v\nWant: %v", g, w)
472+
}
473+
executeRequest := executeRequests[0].(*spannerpb.ExecuteSqlRequest)
474+
if executeRequest.RequestOptions == nil {
475+
t.Fatalf("Execute request options not set")
476+
}
477+
if g, w := executeRequest.RequestOptions.TransactionTag, "my_tag"; g != w {
478+
t.Fatalf("TransactionTag mismatch\n Got: %v\nWant: %v", g, w)
479+
}
480+
commitRequests := testutil.RequestsOfType(requests, reflect.TypeOf(&spannerpb.CommitRequest{}))
481+
if g, w := len(commitRequests), 1; g != w {
482+
t.Fatalf("Commit request count mismatch\n Got: %v\nWant: %v", g, w)
483+
}
484+
commitRequest := commitRequests[0].(*spannerpb.CommitRequest)
485+
if commitRequest.RequestOptions == nil {
486+
t.Fatalf("Commit request options not set")
487+
}
488+
if g, w := commitRequest.RequestOptions.TransactionTag, "my_tag"; g != w {
489+
t.Fatalf("TransactionTag mismatch\n Got: %v\nWant: %v", g, w)
490+
}
491+
492+
if err := CloseConnection(ctx, poolId, connId); err != nil {
493+
t.Fatalf("CloseConnection returned unexpected error: %v", err)
494+
}
495+
if err := ClosePool(ctx, poolId); err != nil {
496+
t.Fatalf("ClosePool returned unexpected error: %v", err)
497+
}
498+
}

0 commit comments

Comments
 (0)