@@ -17,6 +17,7 @@ package api
1717import  (
1818	"context" 
1919	"database/sql" 
20+ 	"database/sql/driver" 
2021	"fmt" 
2122	"strings" 
2223	"sync" 
@@ -25,6 +26,9 @@ import (
2526	"cloud.google.com/go/spanner" 
2627	"cloud.google.com/go/spanner/apiv1/spannerpb" 
2728	spannerdriver "github.com/googleapis/go-sql-spanner" 
29+ 	"google.golang.org/grpc/codes" 
30+ 	"google.golang.org/grpc/status" 
31+ 	"google.golang.org/protobuf/types/known/timestamppb" 
2832)
2933
3034// CloseConnection looks up the connection with the given poolId and connId and closes it. 
@@ -42,6 +46,35 @@ func CloseConnection(ctx context.Context, poolId, connId int64) error {
4246	return  conn .close (ctx )
4347}
4448
49+ // BeginTransaction starts a new transaction on the given connection. 
50+ // A connection can have at most one transaction at any time. This function therefore returns an error if the 
51+ // connection has an active transaction. 
52+ func  BeginTransaction (ctx  context.Context , poolId , connId  int64 , txOpts  * spannerpb.TransactionOptions ) error  {
53+ 	conn , err  :=  findConnection (poolId , connId )
54+ 	if  err  !=  nil  {
55+ 		return  err 
56+ 	}
57+ 	return  conn .BeginTransaction (ctx , txOpts )
58+ }
59+ 
60+ // Commit commits the current transaction on the given connection. 
61+ func  Commit (ctx  context.Context , poolId , connId  int64 ) (* spannerpb.CommitResponse , error ) {
62+ 	conn , err  :=  findConnection (poolId , connId )
63+ 	if  err  !=  nil  {
64+ 		return  nil , err 
65+ 	}
66+ 	return  conn .commit (ctx )
67+ }
68+ 
69+ // Rollback rollbacks the current transaction on the given connection. 
70+ func  Rollback (ctx  context.Context , poolId , connId  int64 ) error  {
71+ 	conn , err  :=  findConnection (poolId , connId )
72+ 	if  err  !=  nil  {
73+ 		return  err 
74+ 	}
75+ 	return  conn .rollback (ctx )
76+ }
77+ 
4578func  Execute (ctx  context.Context , poolId , connId  int64 , executeSqlRequest  * spannerpb.ExecuteSqlRequest ) (int64 , error ) {
4679	conn , err  :=  findConnection (poolId , connId )
4780	if  err  !=  nil  {
@@ -59,23 +92,141 @@ type Connection struct {
5992	backend  * sql.Conn 
6093}
6194
95+ // spannerConn is an internal interface that contains the internal functions that are used by this API. 
96+ // It is implemented by the spannerdriver.conn struct. 
97+ type  spannerConn  interface  {
98+ 	BeginReadOnlyTransaction (ctx  context.Context , options  * spannerdriver.ReadOnlyTransactionOptions ) (driver.Tx , error )
99+ 	BeginReadWriteTransaction (ctx  context.Context , options  * spannerdriver.ReadWriteTransactionOptions ) (driver.Tx , error )
100+ 	Commit (ctx  context.Context ) (* spanner.CommitResponse , error )
101+ 	Rollback (ctx  context.Context ) error 
102+ }
103+ 
62104type  queryExecutor  interface  {
63105	ExecContext (ctx  context.Context , query  string , args  ... any ) (sql.Result , error )
64106	QueryContext (ctx  context.Context , query  string , args  ... any ) (* sql.Rows , error )
65107}
66108
67109func  (conn  * Connection ) close (ctx  context.Context ) error  {
68110	conn .closeResults (ctx )
111+ 	// Rollback any open transactions on the connection. 
112+ 	_  =  conn .rollback (ctx )
113+ 
69114	err  :=  conn .backend .Close ()
70115	if  err  !=  nil  {
71116		return  err 
72117	}
73118	return  nil 
74119}
75120
121+ func  (conn  * Connection ) BeginTransaction (ctx  context.Context , txOpts  * spannerpb.TransactionOptions ) error  {
122+ 	var  err  error 
123+ 	if  txOpts .GetReadOnly () !=  nil  {
124+ 		return  conn .beginReadOnlyTransaction (ctx , convertToReadOnlyOpts (txOpts ))
125+ 	} else  if  txOpts .GetPartitionedDml () !=  nil  {
126+ 		err  =  spanner .ToSpannerError (status .Error (codes .InvalidArgument , "transaction type not supported" ))
127+ 	} else  {
128+ 		return  conn .beginReadWriteTransaction (ctx , convertToReadWriteTransactionOptions (txOpts ))
129+ 	}
130+ 	if  err  !=  nil  {
131+ 		return  err 
132+ 	}
133+ 	return  nil 
134+ }
135+ 
136+ func  (conn  * Connection ) beginReadOnlyTransaction (ctx  context.Context , opts  * spannerdriver.ReadOnlyTransactionOptions ) error  {
137+ 	return  conn .backend .Raw (func (driverConn  any ) (err  error ) {
138+ 		sc , _  :=  driverConn .(spannerConn )
139+ 		_ , err  =  sc .BeginReadOnlyTransaction (ctx , opts )
140+ 		return  err 
141+ 	})
142+ }
143+ 
144+ func  (conn  * Connection ) beginReadWriteTransaction (ctx  context.Context , opts  * spannerdriver.ReadWriteTransactionOptions ) error  {
145+ 	return  conn .backend .Raw (func (driverConn  any ) (err  error ) {
146+ 		sc , _  :=  driverConn .(spannerConn )
147+ 		_ , err  =  sc .BeginReadWriteTransaction (ctx , opts )
148+ 		return  err 
149+ 	})
150+ }
151+ 
152+ func  (conn  * Connection ) commit (ctx  context.Context ) (* spannerpb.CommitResponse , error ) {
153+ 	var  response  * spanner.CommitResponse 
154+ 	if  err  :=  conn .backend .Raw (func (driverConn  any ) (err  error ) {
155+ 		spannerConn , _  :=  driverConn .(spannerConn )
156+ 		response , err  =  spannerConn .Commit (ctx )
157+ 		if  err  !=  nil  {
158+ 			return  err 
159+ 		}
160+ 		return  nil 
161+ 	}); err  !=  nil  {
162+ 		return  nil , err 
163+ 	}
164+ 
165+ 	// The commit response is nil for read-only transactions. 
166+ 	if  response  ==  nil  {
167+ 		return  nil , nil 
168+ 	}
169+ 	// TODO: Include commit stats 
170+ 	return  & spannerpb.CommitResponse {CommitTimestamp : timestamppb .New (response .CommitTs )}, nil 
171+ }
172+ 
173+ func  (conn  * Connection ) rollback (ctx  context.Context ) error  {
174+ 	return  conn .backend .Raw (func (driverConn  any ) (err  error ) {
175+ 		spannerConn , _  :=  driverConn .(spannerConn )
176+ 		return  spannerConn .Rollback (ctx )
177+ 	})
178+ }
179+ 
180+ func  convertToReadOnlyOpts (txOpts  * spannerpb.TransactionOptions ) * spannerdriver.ReadOnlyTransactionOptions  {
181+ 	return  & spannerdriver.ReadOnlyTransactionOptions {
182+ 		TimestampBound : convertTimestampBound (txOpts ),
183+ 	}
184+ }
185+ 
186+ func  convertTimestampBound (txOpts  * spannerpb.TransactionOptions ) spanner.TimestampBound  {
187+ 	ro  :=  txOpts .GetReadOnly ()
188+ 	if  ro .GetStrong () {
189+ 		return  spanner .StrongRead ()
190+ 	} else  if  ro .GetReadTimestamp () !=  nil  {
191+ 		return  spanner .ReadTimestamp (ro .GetReadTimestamp ().AsTime ())
192+ 	} else  if  ro .GetMinReadTimestamp () !=  nil  {
193+ 		return  spanner .ReadTimestamp (ro .GetMinReadTimestamp ().AsTime ())
194+ 	} else  if  ro .GetExactStaleness () !=  nil  {
195+ 		return  spanner .ExactStaleness (ro .GetExactStaleness ().AsDuration ())
196+ 	} else  if  ro .GetMaxStaleness () !=  nil  {
197+ 		return  spanner .MaxStaleness (ro .GetMaxStaleness ().AsDuration ())
198+ 	}
199+ 	return  spanner.TimestampBound {}
200+ }
201+ 
202+ func  convertToReadWriteTransactionOptions (txOpts  * spannerpb.TransactionOptions ) * spannerdriver.ReadWriteTransactionOptions  {
203+ 	readLockMode  :=  spannerpb .TransactionOptions_ReadWrite_READ_LOCK_MODE_UNSPECIFIED 
204+ 	if  txOpts .GetReadWrite () !=  nil  {
205+ 		readLockMode  =  txOpts .GetReadWrite ().GetReadLockMode ()
206+ 	}
207+ 	return  & spannerdriver.ReadWriteTransactionOptions {
208+ 		TransactionOptions : spanner.TransactionOptions {
209+ 			IsolationLevel : txOpts .GetIsolationLevel (),
210+ 			ReadLockMode :   readLockMode ,
211+ 		},
212+ 	}
213+ }
214+ 
215+ func  convertIsolationLevel (level  spannerpb.TransactionOptions_IsolationLevel ) sql.IsolationLevel  {
216+ 	switch  level  {
217+ 	case  spannerpb .TransactionOptions_SERIALIZABLE :
218+ 		return  sql .LevelSerializable 
219+ 	case  spannerpb .TransactionOptions_REPEATABLE_READ :
220+ 		return  sql .LevelRepeatableRead 
221+ 	}
222+ 	return  sql .LevelDefault 
223+ }
224+ 
76225func  (conn  * Connection ) closeResults (ctx  context.Context ) {
77226	conn .results .Range (func (key , value  interface {}) bool  {
78- 		// TODO: Implement 
227+ 		if  r , ok  :=  value .(* rows ); ok  {
228+ 			_  =  r .Close (ctx )
229+ 		}
79230		return  true 
80231	})
81232}
0 commit comments