Skip to content

Commit 029dc8a

Browse files
committed
chore: add transaction support for SpannerLib
1 parent f331cf4 commit 029dc8a

File tree

13 files changed

+1183
-6
lines changed

13 files changed

+1183
-6
lines changed

conn.go

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,34 @@ func (c *conn) getBatchReadOnlyTransactionOptions() BatchReadOnlyTransactionOpti
10711071
return BatchReadOnlyTransactionOptions{TimestampBound: c.ReadOnlyStaleness()}
10721072
}
10731073

1074+
// BeginReadOnlyTransaction is not part of the public API of the database/sql driver.
1075+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1076+
//
1077+
// BeginReadOnlyTransaction starts a new read-only transaction on this connection.
1078+
func (c *conn) BeginReadOnlyTransaction(ctx context.Context, options *ReadOnlyTransactionOptions) (driver.Tx, error) {
1079+
c.withTempReadOnlyTransactionOptions(options)
1080+
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true})
1081+
if err != nil {
1082+
c.withTempReadOnlyTransactionOptions(nil)
1083+
return nil, err
1084+
}
1085+
return tx, nil
1086+
}
1087+
1088+
// BeginReadWriteTransaction is not part of the public API of the database/sql driver.
1089+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1090+
//
1091+
// BeginReadWriteTransaction starts a new read/write transaction on this connection.
1092+
func (c *conn) BeginReadWriteTransaction(ctx context.Context, options *ReadWriteTransactionOptions) (driver.Tx, error) {
1093+
c.withTempTransactionOptions(options)
1094+
tx, err := c.BeginTx(ctx, driver.TxOptions{})
1095+
if err != nil {
1096+
c.withTempTransactionOptions(nil)
1097+
return nil, err
1098+
}
1099+
return tx, nil
1100+
}
1101+
10741102
func (c *conn) Begin() (driver.Tx, error) {
10751103
return c.BeginTx(context.Background(), driver.TxOptions{})
10761104
}
@@ -1254,18 +1282,29 @@ func (c *conn) inReadWriteTransaction() bool {
12541282
return false
12551283
}
12561284

1257-
func (c *conn) commit(ctx context.Context) (*spanner.CommitResponse, error) {
1285+
// Commit is not part of the public API of the database/sql driver.
1286+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1287+
//
1288+
// Commit commits the current transaction on this connection.
1289+
func (c *conn) Commit(ctx context.Context) (*spanner.CommitResponse, error) {
12581290
if !c.inTransaction() {
12591291
return nil, status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12601292
}
12611293
// TODO: Pass in context to the tx.Commit() function.
12621294
if err := c.tx.Commit(); err != nil {
12631295
return nil, err
12641296
}
1265-
return c.CommitResponse()
1297+
1298+
// This will return either the commit response or nil, depending on whether the transaction was a
1299+
// read/write transaction or a read-only transaction.
1300+
return propertyCommitResponse.GetValueOrDefault(c.state), nil
12661301
}
12671302

1268-
func (c *conn) rollback(ctx context.Context) error {
1303+
// Rollback is not part of the public API of the database/sql driver.
1304+
// It is exported for internal reasons, and may receive breaking changes without prior notice.
1305+
//
1306+
// Rollback rollbacks the current transaction on this connection.
1307+
func (c *conn) Rollback(ctx context.Context) error {
12691308
if !c.inTransaction() {
12701309
return status.Errorf(codes.FailedPrecondition, "this connection does not have a transaction")
12711310
}

conn_with_mockserver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,26 @@ func TestTwoTransactionsOnOneConn(t *testing.T) {
8282
}
8383
}
8484

85+
func TestTwoQueriesOnOneConn(t *testing.T) {
86+
t.Parallel()
87+
88+
db, _, teardown := setupTestDBConnection(t)
89+
defer teardown()
90+
ctx := context.Background()
91+
92+
c, _ := db.Conn(ctx)
93+
defer silentClose(c)
94+
95+
for range 2 {
96+
r, err := c.QueryContext(context.Background(), testutil.SelectFooFromBar)
97+
if err != nil {
98+
t.Fatal(err)
99+
}
100+
_ = r.Next()
101+
defer silentClose(r)
102+
}
103+
}
104+
85105
func TestExplicitBeginTx(t *testing.T) {
86106
t.Parallel()
87107

spannerlib/api/connection.go

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package api
1717
import (
1818
"context"
1919
"database/sql"
20+
"database/sql/driver"
2021
"fmt"
2122
"strings"
2223
"sync"
@@ -25,6 +26,9 @@ import (
2526
"cloud.google.com/go/spanner"
2627
"cloud.google.com/go/spanner/apiv1/spannerpb"
2728
spannerdriver "github.com/googleapis/go-sql-spanner"
29+
"google.golang.org/grpc/codes"
30+
"google.golang.org/grpc/status"
31+
"google.golang.org/protobuf/types/known/timestamppb"
2832
)
2933

3034
// CloseConnection looks up the connection with the given poolId and connId and closes it.
@@ -42,6 +46,35 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
4246
return conn.close(ctx)
4347
}
4448

49+
// BeginTransaction starts a new transaction on the given connection.
50+
// A connection can have at most one transaction at any time. This function therefore returns an error if the
51+
// connection has an active transaction.
52+
func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error {
53+
conn, err := findConnection(poolId, connId)
54+
if err != nil {
55+
return err
56+
}
57+
return conn.BeginTransaction(ctx, txOpts)
58+
}
59+
60+
// Commit commits the current transaction on the given connection.
61+
func Commit(ctx context.Context, poolId, connId int64) (*spannerpb.CommitResponse, error) {
62+
conn, err := findConnection(poolId, connId)
63+
if err != nil {
64+
return nil, err
65+
}
66+
return conn.commit(ctx)
67+
}
68+
69+
// Rollback rollbacks the current transaction on the given connection.
70+
func Rollback(ctx context.Context, poolId, connId int64) error {
71+
conn, err := findConnection(poolId, connId)
72+
if err != nil {
73+
return err
74+
}
75+
return conn.rollback(ctx)
76+
}
77+
4578
func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
4679
conn, err := findConnection(poolId, connId)
4780
if err != nil {
@@ -59,23 +92,141 @@ type Connection struct {
5992
backend *sql.Conn
6093
}
6194

95+
// spannerConn is an internal interface that contains the internal functions that are used by this API.
96+
// It is implemented by the spannerdriver.conn struct.
97+
type spannerConn interface {
98+
BeginReadOnlyTransaction(ctx context.Context, options *spannerdriver.ReadOnlyTransactionOptions) (driver.Tx, error)
99+
BeginReadWriteTransaction(ctx context.Context, options *spannerdriver.ReadWriteTransactionOptions) (driver.Tx, error)
100+
Commit(ctx context.Context) (*spanner.CommitResponse, error)
101+
Rollback(ctx context.Context) error
102+
}
103+
62104
type queryExecutor interface {
63105
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
64106
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
65107
}
66108

67109
func (conn *Connection) close(ctx context.Context) error {
68110
conn.closeResults(ctx)
111+
// Rollback any open transactions on the connection.
112+
_ = conn.rollback(ctx)
113+
69114
err := conn.backend.Close()
70115
if err != nil {
71116
return err
72117
}
73118
return nil
74119
}
75120

121+
func (conn *Connection) BeginTransaction(ctx context.Context, txOpts *spannerpb.TransactionOptions) error {
122+
var err error
123+
if txOpts.GetReadOnly() != nil {
124+
return conn.beginReadOnlyTransaction(ctx, convertToReadOnlyOpts(txOpts))
125+
} else if txOpts.GetPartitionedDml() != nil {
126+
err = spanner.ToSpannerError(status.Error(codes.InvalidArgument, "transaction type not supported"))
127+
} else {
128+
return conn.beginReadWriteTransaction(ctx, convertToReadWriteTransactionOptions(txOpts))
129+
}
130+
if err != nil {
131+
return err
132+
}
133+
return nil
134+
}
135+
136+
func (conn *Connection) beginReadOnlyTransaction(ctx context.Context, opts *spannerdriver.ReadOnlyTransactionOptions) error {
137+
return conn.backend.Raw(func(driverConn any) (err error) {
138+
sc, _ := driverConn.(spannerConn)
139+
_, err = sc.BeginReadOnlyTransaction(ctx, opts)
140+
return err
141+
})
142+
}
143+
144+
func (conn *Connection) beginReadWriteTransaction(ctx context.Context, opts *spannerdriver.ReadWriteTransactionOptions) error {
145+
return conn.backend.Raw(func(driverConn any) (err error) {
146+
sc, _ := driverConn.(spannerConn)
147+
_, err = sc.BeginReadWriteTransaction(ctx, opts)
148+
return err
149+
})
150+
}
151+
152+
func (conn *Connection) commit(ctx context.Context) (*spannerpb.CommitResponse, error) {
153+
var response *spanner.CommitResponse
154+
if err := conn.backend.Raw(func(driverConn any) (err error) {
155+
spannerConn, _ := driverConn.(spannerConn)
156+
response, err = spannerConn.Commit(ctx)
157+
if err != nil {
158+
return err
159+
}
160+
return nil
161+
}); err != nil {
162+
return nil, err
163+
}
164+
165+
// The commit response is nil for read-only transactions.
166+
if response == nil {
167+
return nil, nil
168+
}
169+
// TODO: Include commit stats
170+
return &spannerpb.CommitResponse{CommitTimestamp: timestamppb.New(response.CommitTs)}, nil
171+
}
172+
173+
func (conn *Connection) rollback(ctx context.Context) error {
174+
return conn.backend.Raw(func(driverConn any) (err error) {
175+
spannerConn, _ := driverConn.(spannerConn)
176+
return spannerConn.Rollback(ctx)
177+
})
178+
}
179+
180+
func convertToReadOnlyOpts(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadOnlyTransactionOptions {
181+
return &spannerdriver.ReadOnlyTransactionOptions{
182+
TimestampBound: convertTimestampBound(txOpts),
183+
}
184+
}
185+
186+
func convertTimestampBound(txOpts *spannerpb.TransactionOptions) spanner.TimestampBound {
187+
ro := txOpts.GetReadOnly()
188+
if ro.GetStrong() {
189+
return spanner.StrongRead()
190+
} else if ro.GetReadTimestamp() != nil {
191+
return spanner.ReadTimestamp(ro.GetReadTimestamp().AsTime())
192+
} else if ro.GetMinReadTimestamp() != nil {
193+
return spanner.ReadTimestamp(ro.GetMinReadTimestamp().AsTime())
194+
} else if ro.GetExactStaleness() != nil {
195+
return spanner.ExactStaleness(ro.GetExactStaleness().AsDuration())
196+
} else if ro.GetMaxStaleness() != nil {
197+
return spanner.MaxStaleness(ro.GetMaxStaleness().AsDuration())
198+
}
199+
return spanner.TimestampBound{}
200+
}
201+
202+
func convertToReadWriteTransactionOptions(txOpts *spannerpb.TransactionOptions) *spannerdriver.ReadWriteTransactionOptions {
203+
readLockMode := spannerpb.TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED
204+
if txOpts.GetReadWrite() != nil {
205+
readLockMode = txOpts.GetReadWrite().GetReadLockMode()
206+
}
207+
return &spannerdriver.ReadWriteTransactionOptions{
208+
TransactionOptions: spanner.TransactionOptions{
209+
IsolationLevel: txOpts.GetIsolationLevel(),
210+
ReadLockMode: readLockMode,
211+
},
212+
}
213+
}
214+
215+
func convertIsolationLevel(level spannerpb.TransactionOptions_IsolationLevel) sql.IsolationLevel {
216+
switch level {
217+
case spannerpb.TransactionOptions_SERIALIZABLE:
218+
return sql.LevelSerializable
219+
case spannerpb.TransactionOptions_REPEATABLE_READ:
220+
return sql.LevelRepeatableRead
221+
}
222+
return sql.LevelDefault
223+
}
224+
76225
func (conn *Connection) closeResults(ctx context.Context) {
77226
conn.results.Range(func(key, value interface{}) bool {
78-
// TODO: Implement
227+
if r, ok := value.(*rows); ok {
228+
_ = r.Close(ctx)
229+
}
79230
return true
80231
})
81232
}

0 commit comments

Comments
 (0)